Skip to content

Commit 9c5ec2f

Browse files
authored
Fieldsplit: replace empty Forms with ZeroBaseForm (#3947)
* Restricted Cofunction RHS * Fix BCs on Cofunction * LinearSolver: check function spaces * assemble(form, zero_bc_nodes=True) as default * Fix FunctionAssignBlock * Allow Cofunction.assign take in constants * remove BaseFormAssembler test * only supply relevant kwargs to OneFormAssembler * Only interpolate the residual, not every cofunction in the RHS * Fix tests * Fix adjoint utils * More robust test for (unrestricted) Cofunction RHS * Replace empty Jacobians with ZeroBaseForm * Do not split off-diagonal blocks if we only want the diagonal * Zero-simplify slate Tensors * set bcs directly on diagonal Cofunction * ImplicitMatrixContext: handle empty action * Only extract constants referenced in the kernel * Adjoint: only skip expand_derivatives if necessary * EquationBC: do not reconstruct empty Forms * lower degree for EquationBC tests * Update .github/workflows/build.yml
1 parent ad9fe2c commit 9c5ec2f

27 files changed

+310
-320
lines changed

demos/netgen/netgen_mesh.py.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,8 +380,7 @@ We will now show how to solve the Poisson problem on a high-order mesh, of order
380380

381381
bc = DirichletBC(V, 0.0, [1])
382382
A = assemble(a, bcs=bc)
383-
b = assemble(l)
384-
bc.apply(b)
383+
b = assemble(l, bcs=bc)
385384
solve(A, sol, b, solver_parameters={"ksp_type": "cg", "pc_type": "lu"})
386385

387386
VTKFile("output/Sphere.pvd").write(sol)

firedrake/adjoint_utils/blocks/dirichlet_bc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
5151
adj_output = None
5252
for adj_input in adj_inputs:
5353
if isconstant(c):
54-
adj_value = firedrake.Function(self.parent_space.dual())
54+
adj_value = firedrake.Function(self.parent_space)
5555
adj_input.apply(adj_value)
5656
if self.function_space != self.parent_space:
5757
vec = extract_bc_subvector(
@@ -88,11 +88,11 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
8888
# you can even use the Function outside its domain.
8989
# For now we will just assume the FunctionSpace is the same for
9090
# the BC and the Function.
91-
adj_value = firedrake.Function(self.parent_space.dual())
91+
adj_value = firedrake.Function(self.parent_space)
9292
adj_input.apply(adj_value)
9393
r = extract_bc_subvector(
9494
adj_value, c.function_space(), bc
95-
)
95+
).riesz_representation("l2")
9696
if adj_output is None:
9797
adj_output = r
9898
else:

firedrake/adjoint_utils/blocks/function.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
7979
)
8080
diff_expr_assembled = firedrake.Function(adj_input_func.function_space())
8181
diff_expr_assembled.interpolate(ufl.conj(diff_expr))
82+
diff_expr_assembled = diff_expr_assembled.riesz_representation(riesz_map="l2")
8283
adj_output = firedrake.Function(
8384
R, val=firedrake.assemble(ufl.Action(diff_expr_assembled, adj_input_func))
8485
)

firedrake/adjoint_utils/blocks/solving.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -197,14 +197,12 @@ def _assemble_dFdu_adj(self, dFdu_adj_form, **kwargs):
197197

198198
def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):
199199
dJdu_copy = dJdu.copy()
200-
kwargs = self.assemble_kwargs.copy()
201200
# Homogenize and apply boundary conditions on adj_dFdu and dJdu.
202201
bcs = self._homogenize_bcs()
203-
kwargs["bcs"] = bcs
204-
dFdu = self._assemble_dFdu_adj(dFdu_adj_form, **kwargs)
202+
dFdu = firedrake.assemble(dFdu_adj_form, bcs=bcs, **self.assemble_kwargs)
205203

206204
for bc in bcs:
207-
bc.apply(dJdu)
205+
bc.zero(dJdu)
208206

209207
adj_sol = firedrake.Function(self.function_space)
210208
firedrake.solve(
@@ -219,10 +217,8 @@ def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):
219217
return adj_sol, adj_sol_bdy
220218

221219
def _compute_adj_bdy(self, adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu):
222-
adj_sol_bdy = firedrake.Function(
223-
self.function_space.dual(), dJdu.dat - firedrake.assemble(
224-
firedrake.action(dFdu_adj_form, adj_sol)).dat)
225-
return adj_sol_bdy
220+
adj_sol_bdy = firedrake.assemble(dJdu - firedrake.action(dFdu_adj_form, adj_sol))
221+
return adj_sol_bdy.riesz_representation("l2")
226222

227223
def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
228224
prepared=None):
@@ -264,8 +260,11 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
264260
return dFdm
265261

266262
dFdm = -firedrake.derivative(F_form, c_rep, trial_function)
267-
dFdm = firedrake.adjoint(dFdm)
268-
dFdm = dFdm * adj_sol
263+
if isinstance(dFdm, ufl.Form):
264+
dFdm = firedrake.adjoint(dFdm)
265+
dFdm = firedrake.action(dFdm, adj_sol)
266+
else:
267+
dFdm = dFdm(adj_sol)
269268
dFdm = firedrake.assemble(dFdm, **self.assemble_kwargs)
270269
return dFdm
271270

@@ -654,9 +653,8 @@ def _forward_solve(self, lhs, rhs, func, bcs, **kwargs):
654653
def _adjoint_solve(self, dJdu, compute_bdy):
655654
dJdu_copy = dJdu.copy()
656655
# Homogenize and apply boundary conditions on adj_dFdu and dJdu.
657-
bcs = self._homogenize_bcs()
658-
for bc in bcs:
659-
bc.apply(dJdu)
656+
for bc in self.bcs:
657+
bc.zero(dJdu)
660658

661659
if (
662660
self._ad_solvers["forward_nlvs"]._problem._constant_jacobian
@@ -876,7 +874,7 @@ def __init__(self, source, target_space, target, bcs=[], **kwargs):
876874
self.add_dependency(bc, no_duplicates=True)
877875

878876
def apply_mixedmass(self, a):
879-
b = firedrake.Function(self.target_space)
877+
b = firedrake.Function(self.target_space.dual())
880878
with a.dat.vec_ro as vsrc, b.dat.vec_wo as vrhs:
881879
self.mixed_mass.mult(vsrc, vrhs)
882880
return b

firedrake/adjoint_utils/variational_solver.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from functools import wraps
33
from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape, no_annotations
44
from firedrake.adjoint_utils.blocks import NonlinearVariationalSolveBlock
5+
from firedrake.ufl_expr import derivative, adjoint
56
from ufl import replace
67

78

@@ -11,7 +12,6 @@ def _ad_annotate_init(init):
1112
@no_annotations
1213
@wraps(init)
1314
def wrapper(self, *args, **kwargs):
14-
from firedrake import derivative, adjoint, TrialFunction
1515
init(self, *args, **kwargs)
1616
self._ad_F = self.F
1717
self._ad_u = self.u_restrict
@@ -20,10 +20,13 @@ def wrapper(self, *args, **kwargs):
2020
try:
2121
# Some forms (e.g. SLATE tensors) are not currently
2222
# differentiable.
23-
dFdu = derivative(self.F,
24-
self.u_restrict,
25-
TrialFunction(self.u_restrict.function_space()))
26-
self._ad_adj_F = adjoint(dFdu)
23+
dFdu = derivative(self.F, self.u_restrict)
24+
try:
25+
self._ad_adj_F = adjoint(dFdu)
26+
except ValueError:
27+
# Try again without expanding derivatives,
28+
# as dFdu might have been simplied to an empty Form
29+
self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True)
2730
except (TypeError, NotImplementedError):
2831
self._ad_adj_F = None
2932
self._ad_kwargs = {'Jp': self.Jp, 'form_compiler_parameters': self.form_compiler_parameters, 'is_linear': self.is_linear}

firedrake/assemble.py

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def assemble(expr, *args, **kwargs):
8383
zero_bc_nodes : bool
8484
If `True`, set the boundary condition nodes in the
8585
output tensor to zero rather than to the values prescribed by the
86-
boundary condition. Default is `False`.
86+
boundary condition. Default is `True`.
8787
diagonal : bool
8888
If assembling a matrix is it diagonal?
8989
weight : float
@@ -143,7 +143,6 @@ def get_assembler(form, *args, **kwargs):
143143
144144
"""
145145
is_base_form_preprocessed = kwargs.pop('is_base_form_preprocessed', False)
146-
bcs = kwargs.get('bcs', None)
147146
fc_params = kwargs.get('form_compiler_parameters', None)
148147
if isinstance(form, ufl.form.BaseForm) and not is_base_form_preprocessed:
149148
mat_type = kwargs.get('mat_type', None)
@@ -155,8 +154,13 @@ def get_assembler(form, *args, **kwargs):
155154
if len(form.arguments()) == 0:
156155
return ZeroFormAssembler(form, form_compiler_parameters=fc_params)
157156
elif len(form.arguments()) == 1 or diagonal:
158-
return OneFormAssembler(form, *args, bcs=bcs, form_compiler_parameters=fc_params, needs_zeroing=kwargs.get('needs_zeroing', True),
159-
zero_bc_nodes=kwargs.get('zero_bc_nodes', False), diagonal=diagonal)
157+
return OneFormAssembler(form, *args,
158+
bcs=kwargs.get("bcs", None),
159+
form_compiler_parameters=fc_params,
160+
needs_zeroing=kwargs.get("needs_zeroing", True),
161+
zero_bc_nodes=kwargs.get("zero_bc_nodes", True),
162+
diagonal=diagonal,
163+
weight=kwargs.get("weight", 1.0))
160164
elif len(form.arguments()) == 2:
161165
return TwoFormAssembler(form, *args, **kwargs)
162166
else:
@@ -308,7 +312,7 @@ def __init__(self,
308312
sub_mat_type=None,
309313
options_prefix=None,
310314
appctx=None,
311-
zero_bc_nodes=False,
315+
zero_bc_nodes=True,
312316
diagonal=False,
313317
weight=1.0,
314318
allocation_integral_types=None):
@@ -381,6 +385,12 @@ def visitor(e, *operands):
381385
visited = {}
382386
result = BaseFormAssembler.base_form_postorder_traversal(self._form, visitor, visited)
383387

388+
# Apply BCs after assembly
389+
rank = len(self._form.arguments())
390+
if rank == 1 and not isinstance(result, ufl.ZeroBaseForm):
391+
for bc in self._bcs:
392+
bc.zero(result)
393+
384394
if tensor:
385395
BaseFormAssembler.update_tensor(result, tensor)
386396
return tensor
@@ -405,8 +415,8 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
405415
if rank == 0:
406416
assembler = ZeroFormAssembler(form, form_compiler_parameters=self._form_compiler_params)
407417
elif rank == 1 or (rank == 2 and self._diagonal):
408-
assembler = OneFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params,
409-
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal)
418+
assembler = OneFormAssembler(form, form_compiler_parameters=self._form_compiler_params,
419+
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight)
410420
elif rank == 2:
411421
assembler = TwoFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params,
412422
mat_type=self._mat_type, sub_mat_type=self._sub_mat_type,
@@ -577,10 +587,15 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
577587
@staticmethod
578588
def update_tensor(assembled_base_form, tensor):
579589
if isinstance(tensor, (firedrake.Function, firedrake.Cofunction)):
580-
assembled_base_form.dat.copy(tensor.dat)
590+
if isinstance(assembled_base_form, ufl.ZeroBaseForm):
591+
tensor.dat.zero()
592+
else:
593+
assembled_base_form.dat.copy(tensor.dat)
581594
elif isinstance(tensor, matrix.MatrixBase):
582-
# Uses the PETSc copy method.
583-
assembled_base_form.petscmat.copy(tensor.petscmat)
595+
if isinstance(assembled_base_form, ufl.ZeroBaseForm):
596+
tensor.petscmat.zeroEntries()
597+
else:
598+
assembled_base_form.petscmat.copy(tensor.petscmat)
584599
else:
585600
raise NotImplementedError("Cannot update tensor of type %s" % type(tensor))
586601

@@ -807,9 +822,9 @@ def restructure_base_form(expr, visited=None):
807822
return ufl.action(expr, ustar)
808823

809824
# -- Case (6) -- #
810-
if isinstance(expr, ufl.FormSum) and all(isinstance(c, ufl.core.base_form_operator.BaseFormOperator) for c in expr.components()):
811-
# Return ufl.Sum
812-
return sum([c for c in expr.components()])
825+
if isinstance(expr, ufl.FormSum) and all(ufl.duals.is_dual(a.function_space()) for a in expr.arguments()):
826+
# Return ufl.Sum if we are assembling a FormSum with Coarguments (a primal expression)
827+
return sum(w*c for w, c in zip(expr.weights(), expr.components()))
813828
return expr
814829

815830
@staticmethod
@@ -1138,7 +1153,7 @@ class OneFormAssembler(ParloopFormAssembler):
11381153
11391154
Parameters
11401155
----------
1141-
form : ufl.Form or slate.TensorBasehe
1156+
form : ufl.Form or slate.TensorBase
11421157
1-form.
11431158
11441159
Notes
@@ -1149,14 +1164,15 @@ class OneFormAssembler(ParloopFormAssembler):
11491164

11501165
@classmethod
11511166
def _cache_key(cls, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True,
1152-
zero_bc_nodes=False, diagonal=False):
1167+
zero_bc_nodes=True, diagonal=False, weight=1.0):
11531168
bcs = solving._extract_bcs(bcs)
1154-
return tuple(bcs), tuplify(form_compiler_parameters), needs_zeroing, zero_bc_nodes, diagonal
1169+
return tuple(bcs), tuplify(form_compiler_parameters), needs_zeroing, zero_bc_nodes, diagonal, weight
11551170

11561171
@FormAssembler._skip_if_initialised
11571172
def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True,
1158-
zero_bc_nodes=False, diagonal=False):
1173+
zero_bc_nodes=True, diagonal=False, weight=1.0):
11591174
super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters, needs_zeroing=needs_zeroing)
1175+
self._weight = weight
11601176
self._diagonal = diagonal
11611177
self._zero_bc_nodes = zero_bc_nodes
11621178
if self._diagonal and any(isinstance(bc, EquationBCSplit) for bc in self._bcs):
@@ -1185,23 +1201,21 @@ def _apply_bc(self, tensor, bc):
11851201
elif isinstance(bc, EquationBCSplit):
11861202
bc.zero(tensor)
11871203
type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False,
1188-
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor)
1204+
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight).assemble(tensor=tensor)
11891205
else:
11901206
raise AssertionError
11911207

11921208
def _apply_dirichlet_bc(self, tensor, bc):
1193-
if not self._zero_bc_nodes:
1194-
tensor_func = tensor.riesz_representation(riesz_map="l2")
1195-
if self._diagonal:
1196-
bc.set(tensor_func, 1)
1197-
else:
1198-
bc.apply(tensor_func)
1199-
tensor.assign(tensor_func.riesz_representation(riesz_map="l2"))
1209+
if self._diagonal:
1210+
bc.set(tensor, self._weight)
1211+
elif not self._zero_bc_nodes:
1212+
# NOTE this only works if tensor is a Function and not a Cofunction
1213+
bc.apply(tensor)
12001214
else:
12011215
bc.zero(tensor)
12021216

12031217
def _check_tensor(self, tensor):
1204-
if tensor.function_space() != self._form.arguments()[0].function_space():
1218+
if tensor.function_space() != self._form.arguments()[0].function_space().dual():
12051219
raise ValueError("Form's argument does not match provided result tensor")
12061220

12071221
@staticmethod
@@ -2127,14 +2141,13 @@ def iter_active_coefficients(form, kinfo):
21272141

21282142
@staticmethod
21292143
def iter_constants(form, kinfo):
2130-
"""Yield the form constants"""
2144+
"""Yield the form constants referenced in ``kinfo``."""
21312145
if isinstance(form, slate.TensorBase):
2132-
for const in form.constants():
2133-
yield const
2146+
all_constants = form.constants()
21342147
else:
21352148
all_constants = extract_firedrake_constants(form)
2136-
for constant_index in kinfo.constant_numbers:
2137-
yield all_constants[constant_index]
2149+
for constant_index in kinfo.constant_numbers:
2150+
yield all_constants[constant_index]
21382151

21392152
@staticmethod
21402153
def index_function_spaces(form, indices):

firedrake/bcs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -634,10 +634,10 @@ def reconstruct(self, field=None, V=None, subu=None, u=None, row_field=None, col
634634
return
635635
rank = len(self.f.arguments())
636636
splitter = ExtractSubBlock()
637-
if rank == 1:
638-
form = splitter.split(self.f, argument_indices=(row_field, ))
639-
elif rank == 2:
640-
form = splitter.split(self.f, argument_indices=(row_field, col_field))
637+
form = splitter.split(self.f, argument_indices=(row_field, col_field)[:rank])
638+
if isinstance(form, ufl.ZeroBaseForm) or form.empty():
639+
# form is empty, do nothing
640+
return
641641
if u is not None:
642642
form = firedrake.replace(form, {self.u: u})
643643
if action_x is not None:

firedrake/cofunction.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,10 @@ def assign(self, expr, subset=None, expr_from_assemble=False):
229229
return self.assign(
230230
assembled_expr, subset=subset,
231231
expr_from_assemble=True)
232-
233-
raise ValueError('Cannot assign %s' % expr)
232+
else:
233+
from firedrake.assign import Assigner
234+
Assigner(self, expr, subset).assign()
235+
return self
234236

235237
def riesz_representation(self, riesz_map='L2', **solver_options):
236238
"""Return the Riesz representation of this :class:`Cofunction` with respect to the given Riesz map.

0 commit comments

Comments
 (0)