Skip to content

Commit 4193431

Browse files
committed
Add tests
1 parent c55b648 commit 4193431

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

tests/firedrake/regression/test_restricted_function_space.py

+60
Original file line numberDiff line numberDiff line change
@@ -362,3 +362,63 @@ def test_restricted_function_space_extrusion_stokes(ncells):
362362
# -- Actually, the ordering is the same.
363363
assert np.allclose(sol_res.subfunctions[0].dat.data_ro_with_halos, sol.subfunctions[0].dat.data_ro_with_halos)
364364
assert np.allclose(sol_res.subfunctions[1].dat.data_ro_with_halos, sol.subfunctions[1].dat.data_ro_with_halos)
365+
366+
367+
@pytest.mark.parametrize("names", [(None, None), (None, "name1"), ("name0", "name1")])
368+
def test_restrict_fieldsplit(names):
369+
mesh = UnitSquareMesh(2, 2)
370+
V = FunctionSpace(mesh, "CG", 1, name=names[0])
371+
Q = FunctionSpace(mesh, "CG", 2, name=names[1])
372+
Z = V * Q
373+
374+
z = Function(Z)
375+
test = TestFunction(Z)
376+
z_exact = Constant([1, -1])
377+
378+
F = inner(z - z_exact, test) * dx
379+
bcs = [DirichletBC(Z.sub(i), z_exact[i], (i+1, i+3)) for i in range(len(Z))]
380+
381+
problem = NonlinearVariationalProblem(F, z, bcs=bcs, restrict=True)
382+
solver = NonlinearVariationalSolver(problem, solver_parameters={
383+
"ksp_type": "preonly",
384+
"pc_type": "fieldsplit",
385+
"pc_fieldsplit_type": "additive",
386+
f"fieldsplit_{names[0] or 0}_pc_type": "lu",
387+
f"fieldsplit_{names[1] or 1}_pc_type": "lu"},
388+
options_prefix="")
389+
solver.solve()
390+
391+
# Test prefixes for the restricted spaces
392+
pc = solver.snes.ksp.pc
393+
for field, ksp in enumerate(pc.getFieldSplitSubKSP()):
394+
name = Z[field].name or field
395+
assert ksp.getOptionsPrefix() == f"fieldsplit_{name}_"
396+
397+
assert errornorm(z_exact[0], z.subfunctions[0]) < 1E-10
398+
assert errornorm(z_exact[1], z.subfunctions[1]) < 1E-10
399+
400+
401+
def test_restrict_python_pc():
402+
mesh = UnitSquareMesh(2, 2)
403+
x, y = SpatialCoordinate(mesh)
404+
V = FunctionSpace(mesh, "CG", 1)
405+
406+
u = Function(V)
407+
test = TestFunction(V)
408+
u_exact = x + y
409+
g = Function(V).interpolate(u_exact)
410+
411+
F = inner(u - u_exact, test) * dx
412+
bcs = [DirichletBC(V, g, 1), DirichletBC(V, u_exact, 2)]
413+
414+
problem = NonlinearVariationalProblem(F, u, bcs=bcs, restrict=True)
415+
solver = NonlinearVariationalSolver(problem, solver_parameters={
416+
"mat_type": "matfree",
417+
"ksp_type": "preonly",
418+
"pc_type": "python",
419+
"pc_python_type": "firedrake.AssembledPC",
420+
"assembled_pc_type": "lu"},
421+
options_prefix="")
422+
solver.solve()
423+
424+
assert errornorm(u_exact, u) < 1E-10

0 commit comments

Comments
 (0)