@@ -197,50 +197,71 @@ class SubfunctionBlock(Block):
197
197
def __init__(self, func, idx, ad_block_tag=None):
    """Set up a block that picks one subfunction out of ``func``.

    Args:
        func: Object registered as this block's dependency via
            ``add_dependency``.
        idx: Index of the subfunction; kept on the block as ``sub_idx``.
        ad_block_tag: Optional tag handed on to the base class initialiser.
    """
    super().__init__(ad_block_tag=ad_block_tag)
    # Named ``sub_idx`` (not ``idx``), which keeps it distinct from the
    # ``idx`` argument taken by the ``evaluate_*_component`` methods.
    self.sub_idx = idx
    self.add_dependency(func)
201
201
202
202
def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
                           prepared=None):
    """Build the adjoint contribution for the block's dependency.

    A new ``firedrake.Cofunction`` is created on the dual of the function
    space of ``block_variable.output`` and its ``sub_idx`` component is
    assigned from the incoming adjoint value.

    Args:
        inputs: Unused here.
        adj_inputs: Sequence whose first entry is the incoming adjoint
            value; either a ``firedrake.Cofunction`` or an object exposing
            the data to assign through a ``.function`` attribute.
        block_variable: Block variable whose ``output`` provides the
            function space of the result.
        idx: Unused here (this is the dependency index, not ``sub_idx``).
        prepared: Unused here.

    Returns:
        The newly created ``firedrake.Cofunction``.
    """
    adj_in = adj_inputs[0]
    # ``isinstance`` rather than an exact ``type(...) is`` comparison, so
    # that subclasses of Cofunction are assigned directly instead of
    # falling through to the ``.function`` branch.
    if isinstance(adj_in, firedrake.Cofunction):
        source = adj_in
    else:
        source = adj_in.function

    dual_space = block_variable.output.function_space().dual()
    eval_adj = firedrake.Cofunction(dual_space)
    eval_adj.sub(self.sub_idx).assign(source)
    return eval_adj
210
210
211
211
def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
                           prepared=None):
    """Return the ``sub_idx`` component of the incoming tangent-linear value.

    Only ``tlm_inputs[0]`` and ``self.sub_idx`` are read; the remaining
    arguments are unused here.
    """
    tlm_parent = tlm_inputs[0]
    # NOTE(review): ``sub`` is invoked through ``firedrake.Function``
    # explicitly rather than as ``tlm_parent.sub(...)`` - presumably to
    # bypass an overloaded ``sub``; confirm before changing.
    return firedrake.Function.sub(tlm_parent, self.sub_idx)
214
214
215
215
def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
                               block_variable, idx,
                               relevant_dependencies, prepared=None):
    """Build the Hessian contribution for the block's dependency.

    A new ``firedrake.Cofunction`` is created on the dual of the function
    space of ``block_variable.output``; its ``sub_idx`` component is
    assigned from ``hessian_inputs[0]`` and the Cofunction is returned.
    The remaining arguments are unused here.
    """
    dual_space = block_variable.output.function_space().dual()
    eval_hessian = firedrake.Cofunction(dual_space)
    component = eval_hessian.sub(self.sub_idx)
    component.assign(hessian_inputs[0])
    return eval_hessian
221
221
222
222
def recompute_component(self, inputs, block_variable, idx, prepared):
    """Re-extract the ``sub_idx`` subfunction from ``inputs[0]``.

    The extracted subfunction is passed through ``maybe_disk_checkpoint``
    and that call's result is returned. ``block_variable``, ``idx`` and
    ``prepared`` are unused here.
    """
    subfunction = firedrake.Function.sub(inputs[0], self.sub_idx)
    return maybe_disk_checkpoint(subfunction)
226
226
227
227
def __str__(self):
    """Render the block as ``<dependency>[<sub_idx>]``."""
    parent = self.get_dependencies()[0]
    return f"{parent}[{self.sub_idx}]"
229
229
230
230
231
231
class FunctionMergeBlock (Block ):
232
232
def __init__(self, func, idx, ad_block_tag=None):
    """Set up a block that merges a subfunction back into its parent.

    Args:
        func: Registered as the first dependency; every entry of
            ``func._ad_outputs`` is then registered as a further
            dependency, in iteration order.
        idx: Index of the subfunction; kept on the block as ``sub_idx``.
        ad_block_tag: Optional tag handed on to the base class initialiser.
    """
    super().__init__(ad_block_tag=ad_block_tag)
    self.sub_idx = idx
    # Dependency order matters to the component methods, which tell the
    # dependencies apart by position: ``func`` first, then its outputs.
    self.add_dependency(func)
    for parent_output in func._ad_outputs:
        self.add_dependency(parent_output)
238
238
239
239
def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
                           prepared=None):
    # The merge block appears whenever a subfunction is the output of a
    # block. This means that the subfunction has been modified, so that
    # modification must be accounted for when evaluating the adjoint.
    #
    # When the merge block is recomputed, the indexed subfunction of the
    # full Function is completely overwritten, i.e. the pre-existing value
    # of that subfunction in the full Function is ignored.
    # The equivalent adjoint operation is to:
    #   1. send the subfunction component of the adjoint value back up
    #      the branch of the tape corresponding to the subfunction
    #      dependency (idx=0);
    #   2. zero out the subfunction component of the adjoint value sent
    #      back up the full Function branch of the tape (any other idx,
    #      idx=1 in practice).
    # As a result, when the adjoint values of the two branches are
    # combined after the SubfunctionBlock, only the adjoint value from the
    # subfunction branch is used.
    #
    # See https://github.com/firedrakeproject/firedrake/pull/4177 for more
    # detail and for diagrams of the tape produced when accessing
    # subfunctions.

    adj_value = adj_inputs[0]
    adj_sub = adj_value.subfunctions[self.sub_idx]

    if idx == 0:
        # Return an independent copy rather than the component itself.
        # NOTE(review): presumably so that the ``zero()`` call below, made
        # for the other dependency, cannot alter this result - confirm.
        return adj_sub.copy(deepcopy=True)

    # Full-Function branch: clear the merged component on ``adj_value``
    # itself and hand back that same object.
    adj_sub.zero()
    return adj_value
245
266
246
267
def evaluate_tlm (self , markings = False ):
@@ -253,7 +274,7 @@ def evaluate_tlm(self, markings=False):
253
274
fs = output .output .function_space ()
254
275
f = type (output .output )(fs )
255
276
output .add_tlm_output (
256
- type (output .output ).assign (f .sub (self .idx ), tlm_input )
277
+ type (output .output ).assign (f .sub (self .sub_idx ), tlm_input )
257
278
)
258
279
259
280
def evaluate_hessian_component (self , inputs , hessian_inputs , adj_inputs ,
@@ -265,12 +286,12 @@ def recompute_component(self, inputs, block_variable, idx, prepared):
265
286
sub_func = inputs [0 ]
266
287
parent_in = inputs [1 ]
267
288
parent_out = type (parent_in )(parent_in )
268
- parent_out .sub (self .idx ).assign (sub_func )
289
+ parent_out .sub (self .sub_idx ).assign (sub_func )
269
290
return maybe_disk_checkpoint (parent_out )
270
291
271
292
def __str__(self):
    """Render the block as ``<deps[1]>[<sub_idx>].assign(<deps[0]>)``."""
    dependencies = self.get_dependencies()
    parent = dependencies[1]
    subfunction = dependencies[0]
    return f"{parent}[{self.sub_idx}].assign({subfunction})"
274
295
275
296
276
297
class CofunctionAssignBlock (Block ):
0 commit comments