@@ -362,3 +362,63 @@ def test_restricted_function_space_extrusion_stokes(ncells):
362
362
# -- Actually, the ordering is the same.
363
363
assert np .allclose (sol_res .subfunctions [0 ].dat .data_ro_with_halos , sol .subfunctions [0 ].dat .data_ro_with_halos )
364
364
assert np .allclose (sol_res .subfunctions [1 ].dat .data_ro_with_halos , sol .subfunctions [1 ].dat .data_ro_with_halos )
365
+
366
+
367
+ @pytest .mark .parametrize ("names" , [(None , None ), (None , "name1" ), ("name0" , "name1" )])
368
+ def test_restrict_fieldsplit (names ):
369
+ mesh = UnitSquareMesh (2 , 2 )
370
+ V = FunctionSpace (mesh , "CG" , 1 , name = names [0 ])
371
+ Q = FunctionSpace (mesh , "CG" , 2 , name = names [1 ])
372
+ Z = V * Q
373
+
374
+ z = Function (Z )
375
+ test = TestFunction (Z )
376
+ z_exact = Constant ([1 , - 1 ])
377
+
378
+ F = inner (z - z_exact , test ) * dx
379
+ bcs = [DirichletBC (Z .sub (i ), z_exact [i ], (i + 1 , i + 3 )) for i in range (len (Z ))]
380
+
381
+ problem = NonlinearVariationalProblem (F , z , bcs = bcs , restrict = True )
382
+ solver = NonlinearVariationalSolver (problem , solver_parameters = {
383
+ "ksp_type" : "preonly" ,
384
+ "pc_type" : "fieldsplit" ,
385
+ "pc_fieldsplit_type" : "additive" ,
386
+ f"fieldsplit_{ names [0 ] or 0 } _pc_type" : "lu" ,
387
+ f"fieldsplit_{ names [1 ] or 1 } _pc_type" : "lu" },
388
+ options_prefix = "" )
389
+ solver .solve ()
390
+
391
+ # Test prefixes for the restricted spaces
392
+ pc = solver .snes .ksp .pc
393
+ for field , ksp in enumerate (pc .getFieldSplitSubKSP ()):
394
+ name = Z [field ].name or field
395
+ assert ksp .getOptionsPrefix () == f"fieldsplit_{ name } _"
396
+
397
+ assert errornorm (z_exact [0 ], z .subfunctions [0 ]) < 1E-10
398
+ assert errornorm (z_exact [1 ], z .subfunctions [1 ]) < 1E-10
399
+
400
+
401
+ def test_restrict_python_pc ():
402
+ mesh = UnitSquareMesh (2 , 2 )
403
+ x , y = SpatialCoordinate (mesh )
404
+ V = FunctionSpace (mesh , "CG" , 1 )
405
+
406
+ u = Function (V )
407
+ test = TestFunction (V )
408
+ u_exact = x + y
409
+ g = Function (V ).interpolate (u_exact )
410
+
411
+ F = inner (u - u_exact , test ) * dx
412
+ bcs = [DirichletBC (V , g , 1 ), DirichletBC (V , u_exact , 2 )]
413
+
414
+ problem = NonlinearVariationalProblem (F , u , bcs = bcs , restrict = True )
415
+ solver = NonlinearVariationalSolver (problem , solver_parameters = {
416
+ "mat_type" : "matfree" ,
417
+ "ksp_type" : "preonly" ,
418
+ "pc_type" : "python" ,
419
+ "pc_python_type" : "firedrake.AssembledPC" ,
420
+ "assembled_pc_type" : "lu" },
421
+ options_prefix = "" )
422
+ solver .solve ()
423
+
424
+ assert errornorm (u_exact , u ) < 1E-10
0 commit comments