From bbb6817f66472d08f27cc7aa75a651bc37a88dd7 Mon Sep 17 00:00:00 2001 From: Angus Gibson Date: Wed, 10 Sep 2025 19:43:31 +1000 Subject: [PATCH 1/2] Start to move TLM evaluation into NonlinearVariationalSolveBlock --- firedrake/adjoint_utils/blocks/solving.py | 65 ++++++++++++++++++- firedrake/adjoint_utils/variational_solver.py | 31 ++++++++- 2 files changed, 92 insertions(+), 4 deletions(-) diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py index 2c1fb4f876..bd2023a83b 100644 --- a/firedrake/adjoint_utils/blocks/solving.py +++ b/firedrake/adjoint_utils/blocks/solving.py @@ -29,6 +29,7 @@ class Solver(Enum): """Enum for solver types.""" FORWARD = 0 ADJOINT = 1 + TLM = 2 class GenericSolveBlock(Block): @@ -681,8 +682,11 @@ def _adjoint_solve(self, dJdu, compute_bdy): def _ad_assign_map(self, form, solver): if solver == Solver.FORWARD: count_map = self._ad_solvers["forward_nlvs"]._problem._ad_count_map - else: + elif solver == Solver.ADJOINT: count_map = self._ad_solvers["adjoint_lvs"]._problem._ad_count_map + elif solver == Solver.TLM: + count_map = self._ad_solvers["tlm_lvs"]._problem._ad_count_map + assign_map = {} form_ad_count_map = dict((count_map[coeff], coeff) for coeff in form.coefficients()) @@ -717,9 +721,13 @@ def _ad_solver_replace_forms(self, solver=Solver.FORWARD): problem = self._ad_solvers["forward_nlvs"]._problem self._ad_assign_coefficients(problem.F, solver) self._ad_assign_coefficients(problem.J, solver) - else: + elif solver == Solver.ADJOINT: self._ad_assign_coefficients( self._ad_solvers["adjoint_lvs"]._problem.J, solver) + elif solver == Solver.TLM: + self._ad_assign_coefficients( + self._ad_solvers["tlm_lvs"]._problem.J, solver + ) def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): compute_bdy = self._should_compute_boundary_adjoint( @@ -796,6 +804,59 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, return dFdm + def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, + prepared=None): + F_form = prepared["form"] + dFdu = prepared["dFdu"] + + bcs = [] + dFdm = 0. + for block_variable in self.get_dependencies(): + tlm_value = block_variable.tlm_value + c = block_variable.output + c_rep = block_variable.saved_output + + if isinstance(c, firedrake.DirichletBC): + if tlm_value is None: + bcs.append(c.reconstruct(g=0)) + else: + bcs.append(tlm_value) + continue + elif isinstance(c, firedrake.MeshGeometry): + X = firedrake.SpatialCoordinate(c) + c_rep = X + + if tlm_value is None: + continue + + if c == self.func and not self.linear: + continue + + dFdm += firedrake.derivative(-F_form, c_rep, tlm_value) + + if isinstance(dFdm, float): + v = dFdu.arguments()[0] + dFdm = firedrake.inner( + firedrake.Constant(numpy.zeros(v.ufl_shape)), v + ) * firedrake.dx + + dFdm = ufl.algorithms.expand_derivatives(dFdm) + dFdm = firedrake.assemble(dFdm) + + # XXX I dunno how this works + self._ad_solver_replace_forms(Solver.TLM) + self._ad_solvers["tlm_lvs"].invalidate_jacobian() + # update RHS + self._ad_solvers["tlm_lvs"]._problem.F._components[1].assign(dFdm) + + self._ad_solvers["tlm_lvs"].solve() + return self._ad_solvers["tlm_lvs"]._problem.u + # return self._assemble_and_solve_tlm_eq( + # firedrake.assemble(dFdu, bcs=bcs, **self.assemble_kwargs), + # dFdm, dudm, bcs + # ) + + class ProjectBlock(SolveVarFormBlock): def __init__(self, v, V, output, bcs=[], *args, **kwargs): diff --git a/firedrake/adjoint_utils/variational_solver.py b/firedrake/adjoint_utils/variational_solver.py index d1b0af22ca..0cb3b4c713 100644 --- a/firedrake/adjoint_utils/variational_solver.py +++ b/firedrake/adjoint_utils/variational_solver.py @@ -49,7 +49,7 @@ def wrapper(self, problem, *args, **kwargs): self._ad_args = args self._ad_kwargs = kwargs self._ad_solvers = {"forward_nlvs": None, "adjoint_lvs": None, - "recompute_count": 0} + "recompute_count": 0, "tlm_lvs": None} self._ad_adj_cache = {} return wrapper @@ -100,6 +100,14 @@ def wrapper(self, **kwargs): if self._ad_problem._constant_jacobian: self._ad_solvers["update_adjoint"] = False + if not self._ad_solvers["tlm_lvs"]: + with stop_annotating(): + self._ad_solvers["tlm_lvs"] = LinearVariationalSolver( + self._ad_tlm_lvs_problem(block, problem.F, problem.u_restrict) + ) + if self._ad_problem._constant_jacobian: + self._ad_solvers["update_tlm"] = False + block._ad_solvers = self._ad_solvers tape.add_block(block) @@ -151,7 +159,8 @@ def _ad_adj_lvs_problem(self, block, adj_F): # linear variational problem is created with a deep copy of the # `block.adj_F` coefficients. _ad_count_map, J_replace_map, _ = self._build_count_map( - adj_F, block._dependencies) + adj_F, block._dependencies, + ) lvp = LinearVariationalProblem( replace(tmp_problem.J, J_replace_map), right_hand_side, adj_sol, bcs=tmp_problem.bcs, @@ -159,6 +168,24 @@ def _ad_adj_lvs_problem(self, block, adj_F): lvp._ad_count_map_update(_ad_count_map) return lvp + @no_annotations + def _ad_tlm_lvs_problem(self, block, F, u): + from firedrake import Function, Cofunction, LinearVariationalProblem + + lhs = derivative(F, u) + _ad_count_map, F_replace_map, _ = self._build_count_map(lhs, block._dependencies) + sol = Function(block.function_space) + rhs = Cofunction(block.function_space.dual()) + lvp = LinearVariationalProblem( + replace(lhs, F_replace_map), + rhs, + sol, + bcs=block._homogenize_bcs(), + constant_jacobian=self._ad_problem._constant_jacobian, + ) + lvp._ad_count_map_update(_ad_count_map) + return lvp + def _build_count_map(self, J, dependencies, F=None): from firedrake import Function From 538bd82dfa07405c061cf6085ebd60317ba593ac Mon Sep 17 00:00:00 2001 From: Angus Gibson Date: Wed, 10 Sep 2025 22:53:57 +1000 Subject: [PATCH 2/2] Start to move Hessian evaluation into NonlinearVariationalSolveBlock --- firedrake/adjoint_utils/blocks/solving.py | 39 ++++++++++++++----- firedrake/adjoint_utils/variational_solver.py | 33 +++++++++++++++- 2 files changed, 62 insertions(+), 10 deletions(-) diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py index bd2023a83b..e89e0a4c4c 100644 --- a/firedrake/adjoint_utils/blocks/solving.py +++ b/firedrake/adjoint_utils/blocks/solving.py @@ -30,6 +30,7 @@ class Solver(Enum): FORWARD = 0 ADJOINT = 1 TLM = 2 + HESSIAN = 3 class GenericSolveBlock(Block): @@ -221,6 +222,9 @@ def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy): return adj_sol, adj_sol_bdy + def _hessian_solve(self, *args): + return self._assemble_and_solve_adj_eq(*args) + def _compute_adj_bdy(self, adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu): adj_sol_bdy = firedrake.assemble(dJdu - firedrake.action(dFdu_adj_form, adj_sol)) return adj_sol_bdy.riesz_representation("l2") @@ -379,8 +383,7 @@ def _assemble_and_solve_soa_eq(self, dFdu_form, adj_sol, hessian_input, b = self._assemble_soa_eq_rhs(dFdu_form, adj_sol, hessian_input, d2Fdu2) dFdu_form = firedrake.adjoint(dFdu_form) - adj_sol2, adj_sol2_bdy = self._assemble_and_solve_adj_eq(dFdu_form, b, - compute_bdy) + adj_sol2, adj_sol2_bdy = self._hessian_solve(dFdu_form, b, compute_bdy) if self.adj2_cb is not None: self.adj2_cb(adj_sol2) if self.adj2_bdy_cb is not None and compute_bdy: @@ -679,6 +682,22 @@ def _adjoint_solve(self, dJdu, compute_bdy): u_sol, adj_sol_bdy, jac_adj, dJdu_copy) return u_sol, adj_sol_bdy + def _hessian_solve(self, adj_form, rhs, compute_bdy): + # self._ad_solver_replace_forms(Solver.HESSIAN) + # self._ad_solvers["hessian_lvs"].invalidate_jacobian() + self._ad_solvers["hessian_lvs"]._problem.F._components[1].assign(rhs) + self._ad_solvers["hessian_lvs"].solve() + u_sol = self._ad_solvers["hessian_lvs"]._problem.u + + adj_sol_bdy = None + if compute_bdy: + jac_adj = self._ad_solvers["hessian_lvs"]._problem.J + adj_sol_bdy = self._compute_adj_bdy( + u_sol, adj_sol_bdy, jac_adj, rhs.copy() + ) + + return u_sol, adj_sol_bdy + def _ad_assign_map(self, form, solver): if solver == Solver.FORWARD: count_map = self._ad_solvers["forward_nlvs"]._problem._ad_count_map @@ -697,8 +716,10 @@ def _ad_assign_map(self, form, solver): firedrake.Cofunction)): coeff_count = coeff.count() if coeff_count in form_ad_count_map: - assign_map[form_ad_count_map[coeff_count]] = \ - block_variable.saved_output + if solver == Solver.HESSIAN: + assign_map[form_ad_count_map[coeff_count]] = block_variable.tlm_value + else: + assign_map[form_ad_count_map[coeff_count]] = block_variable.saved_output if ( solver == Solver.ADJOINT @@ -709,6 +730,7 @@ def _ad_assign_map(self, form, solver): if coeff_count in form_ad_count_map: assign_map[form_ad_count_map[coeff_count]] = \ block_variable.saved_output + return assign_map def _ad_assign_coefficients(self, form, solver): @@ -728,6 +750,10 @@ def _ad_solver_replace_forms(self, solver=Solver.FORWARD): self._ad_assign_coefficients( self._ad_solvers["tlm_lvs"]._problem.J, solver ) + elif solver == Solver.HESSIAN: + self._ad_assign_coefficients( + self._ad_solvers["hessian_lvs"]._problem.J, solver + ) def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): compute_bdy = self._should_compute_boundary_adjoint( @@ -851,11 +877,6 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, self._ad_solvers["tlm_lvs"].solve() return self._ad_solvers["tlm_lvs"]._problem.u - # return self._assemble_and_solve_tlm_eq( - # firedrake.assemble(dFdu, bcs=bcs, **self.assemble_kwargs), - # dFdm, dudm, bcs - # ) - class ProjectBlock(SolveVarFormBlock): diff --git a/firedrake/adjoint_utils/variational_solver.py b/firedrake/adjoint_utils/variational_solver.py index 0cb3b4c713..8393000a5b 100644 --- a/firedrake/adjoint_utils/variational_solver.py +++ b/firedrake/adjoint_utils/variational_solver.py @@ -17,6 +17,7 @@ def wrapper(self, *args, **kwargs): self._ad_u = self.u_restrict self._ad_bcs = self.bcs self._ad_J = self.J + try: # Some forms (e.g. SLATE tensors) are not currently # differentiable. @@ -27,8 +28,10 @@ def wrapper(self, *args, **kwargs): # Try again without expanding derivatives, # as dFdu might have been simplied to an empty Form self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True) + except (TypeError, NotImplementedError): self._ad_adj_F = None + self._ad_kwargs = {'Jp': self.Jp, 'form_compiler_parameters': self.form_compiler_parameters, 'is_linear': self.is_linear} self._ad_count_map = {} return wrapper @@ -49,7 +52,8 @@ def wrapper(self, problem, *args, **kwargs): self._ad_args = args self._ad_kwargs = kwargs self._ad_solvers = {"forward_nlvs": None, "adjoint_lvs": None, - "recompute_count": 0, "tlm_lvs": None} + "recompute_count": 0, "tlm_lvs": None, + "hessian_lvs": None} self._ad_adj_cache = {} return wrapper @@ -100,6 +104,12 @@ def wrapper(self, **kwargs): if self._ad_problem._constant_jacobian: self._ad_solvers["update_adjoint"] = False + if not self._ad_solvers["hessian_lvs"]: + with stop_annotating(): + self._ad_solvers["hessian_lvs"] = LinearVariationalSolver( + self._ad_hessian_lvs_problem(block, problem._ad_adj_F), + ) + if not self._ad_solvers["tlm_lvs"]: with stop_annotating(): self._ad_solvers["tlm_lvs"] = LinearVariationalSolver( @@ -168,6 +178,27 @@ def _ad_adj_lvs_problem(self, block, adj_F): lvp._ad_count_map_update(_ad_count_map) return lvp + @no_annotations + def _ad_hessian_lvs_problem(self, block, adj_dFdu): + from firedrake import Function, Cofunction, LinearVariationalProblem + + bcs = block._homogenize_bcs() + adj_sol = Function(block.function_space) + right_hand_side = Cofunction(block.function_space.dual()) + tmp_problem = LinearVariationalProblem( + adj_dFdu, right_hand_side, adj_sol, bcs=bcs, + constant_jacobian=self._ad_problem._constant_jacobian) + + _ad_count_map, J_replace_map, _ = self._build_count_map( + adj_dFdu, block._dependencies, + ) + lvp = LinearVariationalProblem( + replace(tmp_problem.J, J_replace_map), right_hand_side, adj_sol, + bcs=tmp_problem.bcs, + constant_jacobian=self._ad_problem._constant_jacobian) + lvp._ad_count_map_update(_ad_count_map) + return lvp + @no_annotations def _ad_tlm_lvs_problem(self, block, F, u): from firedrake import Function, Cofunction, LinearVariationalProblem