Skip to content

Commit fd42aa1

Browse files
committed
RestrictedFunctionSpace: support geometric multigrid
1 parent 475fe05 commit fd42aa1

File tree

7 files changed

+75
-30
lines changed

7 files changed

+75
-30
lines changed

firedrake/dmhooks.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ def refine(dm, comm):
459459
if hasattr(V, "_fine"):
460460
fdm = V._fine.dm
461461
else:
462-
V._fine = firedrake.FunctionSpace(hierarchy[level + 1], V.ufl_element())
462+
V._fine = V.reconstruct(mesh=hierarchy[level + 1])
463463
fdm = V._fine.dm
464464
V._fine._coarse = V
465465
return fdm

firedrake/functionspaceimpl.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def collapse(self):
357357
return type(self).create(self.topological.collapse(), self.mesh())
358358

359359
@classmethod
360-
def make_function_space(cls, mesh, element, name=None):
360+
def make_function_space(cls, mesh, element, name=None, boundary_set=None):
361361
r"""Factory method for :class:`.WithGeometryBase`."""
362362
mesh.init()
363363
topology = mesh.topology
@@ -376,12 +376,18 @@ def make_function_space(cls, mesh, element, name=None):
376376
if mesh is not topology:
377377
# Create a concrete WithGeometry or FiredrakeDualSpace on this mesh
378378
new = cls.create(new, mesh)
379+
380+
if boundary_set:
381+
new = RestrictedFunctionSpace(new, boundary_set=boundary_set)
382+
if mesh is not topology:
383+
new = cls.create(new, mesh)
379384
return new
380385

381-
def reconstruct(self, mesh=None, name=None, **kwargs):
386+
def reconstruct(self, mesh=None, element=None, name=None, **kwargs):
382387
r"""Reconstruct this :class:`.WithGeometryBase` .
383388
384389
:kwarg mesh: the new :func:`~.Mesh` (defaults to same mesh)
390+
:kwarg element: the new :class:`finat.ufl.FiniteElement` (defaults to same element)
385391
:kwarg name: the new name (defaults to None)
386392
:returns: the new function space of the same class as ``self``.
387393
@@ -404,12 +410,14 @@ def reconstruct(self, mesh=None, name=None, **kwargs):
404410
if mesh is None:
405411
mesh = V_parent.mesh()
406412

407-
element = V_parent.ufl_element()
413+
if element is None:
414+
element = V_parent.ufl_element()
408415
cell = mesh.topology.ufl_cell()
409416
if len(kwargs) > 0 or element.cell != cell:
410417
element = element.reconstruct(cell=cell, **kwargs)
411418

412-
V = type(self).make_function_space(mesh, element, name=name)
419+
V = type(self).make_function_space(mesh, element, name=name,
420+
boundary_set=V_parent.boundary_set)
413421
for i in reversed(indices):
414422
V = V.sub(i)
415423
return V

firedrake/mg/interface.py

+15-10
Original file line numberDiff line numberDiff line change
@@ -58,18 +58,17 @@ def prolong(coarse, fine):
5858
repeat = (fine_level - coarse_level)*refinements_per_level
5959
next_level = coarse_level * refinements_per_level
6060

61-
element = Vc.ufl_element()
6261
meshes = hierarchy._meshes
6362
for j in range(repeat):
6463
next_level += 1
6564
if j == repeat - 1:
6665
next = fine
6766
Vf = fine.function_space()
6867
else:
69-
Vf = firedrake.FunctionSpace(meshes[next_level], element)
68+
Vf = Vc.reconstruct(mesh=meshes[next_level])
7069
next = firedrake.Function(Vf)
7170

72-
coarse_coords = Vc.mesh().coordinates
71+
coarse_coords = get_coordinates(Vc)
7372
fine_to_coarse = utils.fine_node_to_coarse_node_map(Vf, Vc)
7473
fine_to_coarse_coords = utils.fine_node_to_coarse_node_map(Vf, coarse_coords.function_space())
7574
kernel = kernels.prolong_kernel(coarse)
@@ -119,7 +118,6 @@ def restrict(fine_dual, coarse_dual):
119118
repeat = (fine_level - coarse_level)*refinements_per_level
120119
next_level = fine_level * refinements_per_level
121120

122-
element = Vc.ufl_element()
123121
meshes = hierarchy._meshes
124122

125123
for j in range(repeat):
@@ -128,15 +126,15 @@ def restrict(fine_dual, coarse_dual):
128126
coarse_dual.dat.zero()
129127
next = coarse_dual
130128
else:
131-
Vc = firedrake.FunctionSpace(meshes[next_level], element)
129+
Vc = Vf.reconstruct(mesh=meshes[next_level])
132130
next = firedrake.Cofunction(Vc.dual())
133131
Vc = next.function_space()
134132
# XXX: Should be able to figure out locations by pushing forward
135133
# reference cell node locations to physical space.
136134
# x = \sum_i c_i \phi_i(x_hat)
137-
node_locations = utils.physical_node_locations(Vf)
135+
node_locations = utils.physical_node_locations(Vf.dual())
138136

139-
coarse_coords = Vc.mesh().coordinates
137+
coarse_coords = get_coordinates(Vc.dual())
140138
fine_to_coarse = utils.fine_node_to_coarse_node_map(Vf, Vc)
141139
fine_to_coarse_coords = utils.fine_node_to_coarse_node_map(Vf, coarse_coords.function_space())
142140
# Have to do this, because the node set core size is not right for
@@ -195,7 +193,6 @@ def inject(fine, coarse):
195193
repeat = (fine_level - coarse_level)*refinements_per_level
196194
next_level = fine_level * refinements_per_level
197195

198-
element = Vc.ufl_element()
199196
meshes = hierarchy._meshes
200197

201198
for j in range(repeat):
@@ -205,12 +202,12 @@ def inject(fine, coarse):
205202
next = coarse
206203
Vc = next.function_space()
207204
else:
208-
Vc = firedrake.FunctionSpace(meshes[next_level], element)
205+
Vc = Vf.reconstruct(mesh=meshes[next_level])
209206
next = firedrake.Function(Vc)
210207
if not dg:
211208
node_locations = utils.physical_node_locations(Vc)
212209

213-
fine_coords = Vf.mesh().coordinates
210+
fine_coords = get_coordinates(Vf)
214211
coarse_node_to_fine_nodes = utils.coarse_node_to_fine_node_map(Vc, Vf)
215212
coarse_node_to_fine_coords = utils.coarse_node_to_fine_node_map(Vc, fine_coords.function_space())
216213

@@ -242,3 +239,11 @@ def inject(fine, coarse):
242239
fine = next
243240
Vf = Vc
244241
return coarse
242+
243+
244+
def get_coordinates(V):
245+
coords = V.mesh().coordinates
246+
if V.boundary_set:
247+
W = V.reconstruct(element=coords.function_space().ufl_element())
248+
coords = firedrake.Function(W).interpolate(coords)
249+
return coords

firedrake/mg/ufl_utils.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def inject_on_restrict(fine, restriction, rscale, injection, coarse):
186186
if isinstance(g, firedrake.Function) and hasattr(g, "_child"):
187187
manager.inject(g, g._child)
188188

189-
V = problem.u.function_space()
189+
V = problem.u_restrict.function_space()
190190
if not hasattr(V, "_coarse"):
191191
# The hook is persistent and cumulative, but also problem-independent.
192192
# Therefore, we are only adding it once.
@@ -201,7 +201,7 @@ def inject_on_restrict(fine, restriction, rscale, injection, coarse):
201201
for c in coefficients:
202202
coefficient_mapping[c] = self(c, self, coefficient_mapping=coefficient_mapping)
203203

204-
u = coefficient_mapping[problem.u]
204+
u = coefficient_mapping[problem.u_restrict]
205205

206206
bcs = [self(bc, self) for bc in problem.bcs]
207207
J = self(problem.J, self, coefficient_mapping=coefficient_mapping)
@@ -277,7 +277,7 @@ def coarsen_snescontext(context, self, coefficient_mapping=None):
277277
if isinstance(val, (firedrake.Function, firedrake.Cofunction)):
278278
V = val.function_space()
279279
coarseneddm = V.dm
280-
parentdm = get_parent(context._problem.u.function_space().dm)
280+
parentdm = get_parent(context._problem.u_restrict.function_space().dm)
281281

282282
# Now attach the hook to the parent DM
283283
if get_appctx(coarseneddm) is None:
@@ -369,8 +369,8 @@ def create_interpolation(dmc, dmf):
369369

370370
manager = get_transfer_manager(dmf)
371371

372-
V_c = cctx._problem.u.function_space()
373-
V_f = fctx._problem.u.function_space()
372+
V_c = cctx._problem.u_restrict.function_space()
373+
V_f = fctx._problem.u_restrict.function_space()
374374

375375
row_size = V_f.dof_dset.layout_vec.getSizes()
376376
col_size = V_c.dof_dset.layout_vec.getSizes()
@@ -395,8 +395,8 @@ def create_injection(dmc, dmf):
395395

396396
manager = get_transfer_manager(dmf)
397397

398-
V_c = cctx._problem.u.function_space()
399-
V_f = fctx._problem.u.function_space()
398+
V_c = cctx._problem.u_restrict.function_space()
399+
V_f = fctx._problem.u_restrict.function_space()
400400

401401
row_size = V_f.dof_dset.layout_vec.getSizes()
402402
col_size = V_c.dof_dset.layout_vec.getSizes()

firedrake/mg/utils.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
def fine_node_to_coarse_node_map(Vf, Vc):
1212
if len(Vf) > 1:
1313
assert len(Vf) == len(Vc)
14-
return op2.MixedMap(fine_node_to_coarse_node_map(f, c) for f, c in zip(Vf, Vc))
14+
return op2.MixedMap(map(fine_node_to_coarse_node_map, Vf, Vc))
1515
mesh = Vf.mesh()
1616
assert hasattr(mesh, "_shared_data_cache")
1717
hierarchyf, levelf = get_level(Vf.mesh())
@@ -49,7 +49,7 @@ def fine_node_to_coarse_node_map(Vf, Vc):
4949
def coarse_node_to_fine_node_map(Vc, Vf):
5050
if len(Vf) > 1:
5151
assert len(Vf) == len(Vc)
52-
return op2.MixedMap(coarse_node_to_fine_node_map(f, c) for f, c in zip(Vf, Vc))
52+
return op2.MixedMap(map(coarse_node_to_fine_node_map, Vf, Vc))
5353
mesh = Vc.mesh()
5454
assert hasattr(mesh, "_shared_data_cache")
5555
hierarchyf, levelf = get_level(Vf.mesh())
@@ -146,7 +146,8 @@ def physical_node_locations(V):
146146
try:
147147
return cache[key]
148148
except KeyError:
149-
Vc = firedrake.VectorFunctionSpace(mesh, element)
149+
Vc = V.reconstruct(element=finat.ufl.VectorElement(element, dim=mesh.geometric_dimension()))
150+
150151
# FIXME: This is unsafe for DG coordinates and CG target spaces.
151152
locations = firedrake.assemble(firedrake.Interpolate(firedrake.SpatialCoordinate(mesh), Vc))
152153
return cache.setdefault(key, locations)

pyop2/types/dat.py

-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ def _wrapper_cache_key_(self):
105105
@utils.validate_in(('access', _modes, ex.ModeValueError))
106106
def __call__(self, access, path=None):
107107
from pyop2.parloop import DatLegacyArg
108-
109108
if conf.configuration["type_check"] and path and path.toset != self.dataset.set:
110109
raise ex.MapValueError("To Set of Map does not match Set of Dat.")
111110
return DatLegacyArg(self, path, access)

tests/firedrake/regression/test_restricted_function_space.py

+36-4
Original file line numberDiff line numberDiff line change
@@ -401,11 +401,11 @@ def test_restrict_fieldsplit(names):
401401

402402
def test_restrict_python_pc():
403403
mesh = UnitSquareMesh(2, 2)
404-
x, y = SpatialCoordinate(mesh)
405404
V = FunctionSpace(mesh, "CG", 1)
406-
407405
u = Function(V)
408406
test = TestFunction(V)
407+
408+
x, y = SpatialCoordinate(mesh)
409409
u_exact = x + y
410410
g = Function(V).interpolate(u_exact)
411411

@@ -419,8 +419,40 @@ def test_restrict_python_pc():
419419
"ksp_type": "preonly",
420420
"pc_type": "python",
421421
"pc_python_type": "firedrake.AssembledPC",
422-
"assembled_pc_type": "lu"},
423-
options_prefix="")
422+
"assembled_pc_type": "lu"})
423+
solver.solve()
424+
425+
assert errornorm(u_exact, u) < 1E-10
426+
427+
428+
def test_restrict_multigrid():
429+
base = UnitSquareMesh(2, 2)
430+
refine = 2
431+
mh = MeshHierarchy(base, refine)
432+
mesh = mh[-1]
433+
434+
V = FunctionSpace(mesh, "CG", 1)
435+
u = Function(V)
436+
test = TestFunction(V)
437+
438+
x, y = SpatialCoordinate(mesh)
439+
u_exact = x + y
440+
g = Function(V).interpolate(u_exact)
441+
442+
F = inner(grad(u - u_exact), grad(test)) * dx
443+
bcs = [DirichletBC(V, g, 1), DirichletBC(V, u_exact, 2)]
444+
445+
problem = NonlinearVariationalProblem(F, u, bcs=bcs, restrict=True)
446+
solver = NonlinearVariationalSolver(problem, solver_parameters={
447+
"snes_type": "ksponly",
448+
"ksp_type": "cg",
449+
"ksp_rtol": 1E-10,
450+
"ksp_max_it": 10,
451+
"ksp_monitor": None,
452+
"pc_type": "mg",
453+
"mg_levels_ksp_type": "chebyshev",
454+
"mg_levels_pc_type": "jacobi",
455+
"mg_coarse_pc_type": "lu"})
424456
solver.solve()
425457

426458
assert errornorm(u_exact, u) < 1E-10

0 commit comments

Comments
 (0)