Skip to content

Commit bc82efd

Browse files
JHopeCollinspbrubeck
authored andcommitted
Fix subfunction merge block adjoint (#4177)
* comment to explain how we evaluate the adjoint of the FunctionMergeBlock
1 parent b818105 commit bc82efd

File tree

2 files changed

+56
-12
lines changed

2 files changed

+56
-12
lines changed

firedrake/adjoint_utils/blocks/function.py

+33-12
Original file line numberDiff line numberDiff line change
@@ -197,50 +197,71 @@ class SubfunctionBlock(Block):
197197
def __init__(self, func, idx, ad_block_tag=None):
198198
super().__init__(ad_block_tag=ad_block_tag)
199199
self.add_dependency(func)
200-
self.idx = idx
200+
self.sub_idx = idx
201201

202202
def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
203203
prepared=None):
204204
eval_adj = firedrake.Cofunction(block_variable.output.function_space().dual())
205205
if type(adj_inputs[0]) is firedrake.Cofunction:
206-
eval_adj.sub(self.idx).assign(adj_inputs[0])
206+
eval_adj.sub(self.sub_idx).assign(adj_inputs[0])
207207
else:
208-
eval_adj.sub(self.idx).assign(adj_inputs[0].function)
208+
eval_adj.sub(self.sub_idx).assign(adj_inputs[0].function)
209209
return eval_adj
210210

211211
def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
212212
prepared=None):
213-
return firedrake.Function.sub(tlm_inputs[0], self.idx)
213+
return firedrake.Function.sub(tlm_inputs[0], self.sub_idx)
214214

215215
def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
216216
block_variable, idx,
217217
relevant_dependencies, prepared=None):
218218
eval_hessian = firedrake.Cofunction(block_variable.output.function_space().dual())
219-
eval_hessian.sub(self.idx).assign(hessian_inputs[0])
219+
eval_hessian.sub(self.sub_idx).assign(hessian_inputs[0])
220220
return eval_hessian
221221

222222
def recompute_component(self, inputs, block_variable, idx, prepared):
223223
return maybe_disk_checkpoint(
224-
firedrake.Function.sub(inputs[0], self.idx)
224+
firedrake.Function.sub(inputs[0], self.sub_idx)
225225
)
226226

227227
def __str__(self):
228-
return f"{self.get_dependencies()[0]}[{self.idx}]"
228+
return f"{self.get_dependencies()[0]}[{self.sub_idx}]"
229229

230230

231231
class FunctionMergeBlock(Block):
232232
def __init__(self, func, idx, ad_block_tag=None):
233233
super().__init__(ad_block_tag=ad_block_tag)
234234
self.add_dependency(func)
235-
self.idx = idx
235+
self.sub_idx = idx
236236
for output in func._ad_outputs:
237237
self.add_dependency(output)
238238

239239
def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
240240
prepared=None):
241+
# The merge block appears whenever a subfunction is the output of a block.
242+
# This means that the subfunction has been modified, so we need to make
243+
# sure that this modification is accounted for when evaluating the adjoint.
244+
#
245+
# When recomputing the merge block, the indexed subfunction in the full
246+
# Function is completely overwritten, meaning that the pre-existing value
247+
# of the subfunction in the full function is ignored.
248+
# The equivalent adjoint operation is to:
249+
# 1. send the subfunction component of the adjoint value back up
250+
# the branch of the tape corresponding to the subfunction
251+
# dependency (idx=0).
252+
# 2. zero out the subfunction component of the adjoint value sent
253+
# back up the full Function branch of the tape (idx=1).
254+
# This means that when the adjoint values of each branch are combined
255+
# after the SubfunctionBlock only the adjoint value from the subfunction
256+
# branch is used.
257+
#
258+
# See https://github.com/firedrakeproject/firedrake/pull/4177 for more
259+
# detail and for diagrams of the tape produced when accessing subfunctions.
260+
241261
if idx == 0:
242-
return adj_inputs[0].subfunctions[self.idx]
262+
return adj_inputs[0].subfunctions[self.sub_idx].copy(deepcopy=True)
243263
else:
264+
adj_inputs[0].subfunctions[self.sub_idx].zero()
244265
return adj_inputs[0]
245266

246267
def evaluate_tlm(self, markings=False):
@@ -253,7 +274,7 @@ def evaluate_tlm(self, markings=False):
253274
fs = output.output.function_space()
254275
f = type(output.output)(fs)
255276
output.add_tlm_output(
256-
type(output.output).assign(f.sub(self.idx), tlm_input)
277+
type(output.output).assign(f.sub(self.sub_idx), tlm_input)
257278
)
258279

259280
def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
@@ -265,12 +286,12 @@ def recompute_component(self, inputs, block_variable, idx, prepared):
265286
sub_func = inputs[0]
266287
parent_in = inputs[1]
267288
parent_out = type(parent_in)(parent_in)
268-
parent_out.sub(self.idx).assign(sub_func)
289+
parent_out.sub(self.sub_idx).assign(sub_func)
269290
return maybe_disk_checkpoint(parent_out)
270291

271292
def __str__(self):
272293
deps = self.get_dependencies()
273-
return f"{deps[1]}[{self.idx}].assign({deps[0]})"
294+
return f"{deps[1]}[{self.sub_idx}].assign({deps[0]})"
274295

275296

276297
class CofunctionAssignBlock(Block):

tests/firedrake/adjoint/test_split_and_subfunctions.py

+23
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def test_subfunctions_always_create_blocks():
183183
# force the subfunctions to be created
184184
_ = kappa.subfunctions
185185

186+
continue_annotation()
186187
with set_working_tape() as tape:
187188
u.assign(kappa.subfunctions[0])
188189
J = assemble(inner(u, u)*dx)
@@ -191,3 +192,25 @@ def test_subfunctions_always_create_blocks():
191192

192193
rf.derivative()
193194
assert control.block_variable.adj_value is not None, "Functional should depend on Control"
195+
196+
197+
@pytest.mark.skipcomplex
198+
def test_writing_to_subfunctions():
199+
with stop_annotating():
200+
mesh = UnitIntervalMesh(1)
201+
R = FunctionSpace(mesh, "R", 0)
202+
203+
kappa = Function(R).assign(2.0)
204+
u = Function(R)
205+
usub = u.subfunctions[0]
206+
207+
continue_annotation()
208+
with set_working_tape() as tape:
209+
u.assign(kappa)
210+
usub *= 2
211+
J = assemble(inner(u, u) * dx)
212+
print(f"{type(J) = }")
213+
rf = ReducedFunctional(J, Control(kappa), tape=tape)
214+
pause_annotation()
215+
216+
assert taylor_test(rf, kappa, Constant(0.1)) > 1.9

0 commit comments

Comments
 (0)