Skip to content

Commit 237c17e

Browse files
authored
Use petsctools in tests (#4447)
* use petsctools for the OptionsManager in adjoint test
1 parent ac5e291 commit 237c17e

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

tests/firedrake/adjoint/test_optimisation.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from firedrake import *
77
from firedrake.adjoint import *
88
from pyadjoint import Block, MinimizationProblem, TAOSolver, get_working_tape
9-
from pyadjoint.optimization.tao_solver import OptionsManager, PETScVecInterface
9+
from pyadjoint.optimization.tao_solver import PETScVecInterface
10+
import petsctools
1011

1112

1213
@pytest.fixture(autouse=True)
@@ -68,7 +69,9 @@ def test_petsc_roundtrip_multiple():
6869
def minimize_tao_lmvm(rf, *, convert_options=None):
6970
problem = MinimizationProblem(rf)
7071
solver = TAOSolver(problem, {"tao_type": "lmvm",
71-
"tao_gatol": 1.0e-7,
72+
"tao_monitor": None,
73+
"tao_converged_reason": None,
74+
"tao_gatol": 1.0e-5,
7275
"tao_grtol": 0.0,
7376
"tao_gttol": 0.0},
7477
convert_options=convert_options)
@@ -78,7 +81,9 @@ def minimize_tao_lmvm(rf, *, convert_options=None):
7881
def minimize_tao_nls(rf, *, convert_options=None):
7982
problem = MinimizationProblem(rf)
8083
solver = TAOSolver(problem, {"tao_type": "nls",
81-
"tao_gatol": 1.0e-7,
84+
"tao_monitor": None,
85+
"tao_converged_reason": None,
86+
"tao_gatol": 1.0e-5,
8287
"tao_grtol": 0.0,
8388
"tao_gttol": 0.0},
8489
convert_options=convert_options)
@@ -218,11 +223,13 @@ def mult(self, A, x, y):
218223
M_mat.setUp()
219224

220225
mfn = SLEPc.MFN().create(comm=comm)
221-
options = OptionsManager(mfn_parameters, None)
222-
options.set_default_parameter("fn_type", "sqrt")
226+
petsctools.attach_options(
227+
mfn, parameters=mfn_parameters,
228+
options_prefix=None)
229+
petsctools.set_default_parameter(mfn, "fn_type", "sqrt")
223230
mfn.setOperator(M_mat)
224231

225-
options.set_from_options(mfn)
232+
petsctools.set_from_options(mfn)
226233
mfn.setUp()
227234
if mfn.getFN().getType() != SLEPc.FN.Type.SQRT:
228235
raise ValueError("Invalid FN type")
@@ -234,7 +241,8 @@ def mult(self, A, x, y):
234241
if y.norm(PETSc.NormType.NORM_INFINITY) == 0:
235242
x.zeroEntries()
236243
else:
237-
mfn.solve(y, x)
244+
with petsctools.inserted_options(mfn):
245+
mfn.solve(y, x)
238246
if mfn.getConvergedReason() <= 0:
239247
raise RuntimeError("Convergence failure")
240248

@@ -281,10 +289,11 @@ def test_simple_inversion_riesz_representation(tao_type):
281289
mfn_parameters = {"mfn_type": "krylov",
282290
"mfn_tol": 1.0e-12}
283291
tao_parameters = {"tao_type": tao_type,
292+
"tao_monitor": None,
293+
"tao_converged_reason": None,
284294
"tao_gatol": 1.0e-5,
285295
"tao_grtol": 0.0,
286-
"tao_gttol": 0.0,
287-
"tao_monitor": None}
296+
"tao_gttol": 0.0}
288297

289298
with stop_annotating():
290299
mesh = UnitIntervalMesh(10)
@@ -355,6 +364,8 @@ def test_tao_bounds():
355364
lb = 0.5 - 7.0 / 11.0
356365
problem = MinimizationProblem(rf, bounds=(lb, None))
357366
solver = TAOSolver(problem, {"tao_type": "bnls",
367+
"tao_monitor": None,
368+
"tao_converged_reason": None,
358369
"tao_gatol": 1.0e-7,
359370
"tao_grtol": 0.0,
360371
"tao_gttol": 0.0})

0 commit comments

Comments
 (0)