Skip to content

Commit 3c63f6c

Browse files
committed
RestrictedFunctionSpace: support p-multigrid
1 parent fd42aa1 commit 3c63f6c

File tree

5 files changed

+37
-35
lines changed

5 files changed

+37
-35
lines changed

firedrake/mg/embedded.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,17 @@ def cache(self, V):
9090
except KeyError:
9191
return self.caches.setdefault(key, TransferManager.Cache(*key))
9292

93+
def cache_key(self, V):
94+
return (V.dim(), tuple(V.boundary_set))
95+
9396
def V_dof_weights(self, V):
9497
"""Dof weights for averaging projection.
9598
9699
:arg V: function space to compute weights for.
97100
:returns: A PETSc Vec.
98101
"""
99102
cache = self.cache(V)
100-
key = V.dim()
103+
key = self.cache_key(V)
101104
try:
102105
return cache._V_dof_weights[key]
103106
except KeyError:
@@ -122,7 +125,7 @@ def V_DG_mass(self, V, DG):
122125
:returns: A PETSc Mat mapping from V -> DG
123126
"""
124127
cache = self.cache(V)
125-
key = V.dim()
128+
key = self.cache_key(V)
126129
try:
127130
return cache._V_DG_mass[key]
128131
except KeyError:
@@ -153,7 +156,7 @@ def V_approx_inv_mass(self, V, DG):
153156
:returns: A PETSc Mat mapping from V -> DG.
154157
"""
155158
cache = self.cache(V)
156-
key = V.dim()
159+
key = self.cache_key(V)
157160
try:
158161
return cache._V_approx_inv_mass[key]
159162
except KeyError:
@@ -171,7 +174,7 @@ def V_inv_mass_ksp(self, V):
171174
:returns: A PETSc KSP for inverting (V, V).
172175
"""
173176
cache = self.cache(V)
174-
key = V.dim()
177+
key = self.cache_key(V)
175178
try:
176179
return cache._V_inv_mass_ksp[key]
177180
except KeyError:
@@ -193,7 +196,7 @@ def DG_work(self, V):
193196
"""
194197
needs_dual = ufl.duals.is_dual(V)
195198
cache = self.cache(V)
196-
key = (V.dim(), needs_dual)
199+
key = self.cache_key(V) + (needs_dual,)
197200
try:
198201
return cache._DG_work[key]
199202
except KeyError:
@@ -210,7 +213,7 @@ def work_vec(self, V):
210213
:returns: A PETSc Vec for V.
211214
"""
212215
cache = self.cache(V)
213-
key = V.dim()
216+
key = self.cache_key(V)
214217
try:
215218
return cache._work_vec[key]
216219
except KeyError:

firedrake/mg/interface.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def restrict(fine_dual, coarse_dual):
127127
next = coarse_dual
128128
else:
129129
Vc = Vf.reconstruct(mesh=meshes[next_level])
130-
next = firedrake.Cofunction(Vc.dual())
130+
next = firedrake.Cofunction(Vc)
131131
Vc = next.function_space()
132132
# XXX: Should be able to figure out locations by pushing forward
133133
# reference cell node locations to physical space.

firedrake/mg/utils.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,7 @@ def fine_node_to_coarse_node_map(Vf, Vc):
2525
if levelc + increment != levelf:
2626
raise ValueError("Can't map between level %s and level %s" % (levelc, levelf))
2727

28-
key = (entity_dofs_key(Vc.finat_element.entity_dofs())
29-
+ entity_dofs_key(Vf.finat_element.entity_dofs())
30-
+ (levelc, levelf))
31-
28+
key = _cache_key(Vc, Vf)
3229
cache = mesh._shared_data_cache["hierarchy_fine_node_to_coarse_node_map"]
3330
try:
3431
return cache[key]
@@ -63,10 +60,7 @@ def coarse_node_to_fine_node_map(Vc, Vf):
6360
if levelc + increment != levelf:
6461
raise ValueError("Can't map between level %s and level %s" % (levelc, levelf))
6562

66-
key = (entity_dofs_key(Vc.finat_element.entity_dofs())
67-
+ entity_dofs_key(Vf.finat_element.entity_dofs())
68-
+ (levelc, levelf))
69-
63+
key = _cache_key(Vc, Vf)
7064
cache = mesh._shared_data_cache["hierarchy_coarse_node_to_fine_node_map"]
7165
try:
7266
return cache[key]
@@ -101,7 +95,7 @@ def coarse_cell_to_fine_node_map(Vc, Vf):
10195
if levelc + increment != levelf:
10296
raise ValueError("Can't map between level %s and level %s" % (levelc, levelf))
10397

104-
key = (entity_dofs_key(Vf.finat_element.entity_dofs()) + (levelc, levelf))
98+
key = _cache_key(Vc, Vf)
10599
cache = mesh._shared_data_cache["hierarchy_coarse_cell_to_fine_node_map"]
106100
try:
107101
return cache[key]
@@ -142,7 +136,7 @@ def physical_node_locations(V):
142136
# This is a defaultdict, so the first time we access the key we
143137
# get a fresh dict for the cache.
144138
cache = mesh._geometric_shared_data_cache["hierarchy_physical_node_locations"]
145-
key = element
139+
key = (element, tuple(V.boundary_set))
146140
try:
147141
return cache[key]
148142
except KeyError:
@@ -172,3 +166,10 @@ def get_level(obj):
172166
def has_level(obj):
173167
"""Does the provided object have level info?"""
174168
return hasattr(obj.topological, "__level_info__")
169+
170+
171+
def _cache_key(Vc, Vf):
172+
return (entity_dofs_key(Vc.finat_element.entity_dofs())
173+
+ entity_dofs_key(Vf.finat_element.entity_dofs())
174+
+ (Vc.dim(), Vf.dim())
175+
+ (tuple(Vc.boundary_set), tuple(Vf.boundary_set)))

firedrake/preconditioners/pmg.py

+13-16
Original file line numberDiff line numberDiff line change
@@ -181,21 +181,21 @@ def coarsen(self, fdm, comm):
181181
assert parent is not None
182182

183183
test, trial = fctx.J.arguments()
184-
fV = test.function_space()
184+
fV = trial.function_space()
185185
cele = self.coarsen_element(fV.ufl_element())
186186

187187
# Have we already done this?
188188
cctx = fctx._coarse
189189
if cctx is not None:
190-
cV = cctx.J.arguments()[0].function_space()
191-
if (cV.ufl_element() == cele) and (cV.mesh() == fV.mesh()):
190+
cV = cctx.J.arguments()[1].function_space()
191+
if (cV.ufl_element() == cele) and (cV.mesh() == fV.mesh()) and all(cV_.boundary_set == fV_.boundary_set for cV_, fV_ in zip(cV, fV)):
192192
return cV.dm
193193

194-
cV = firedrake.FunctionSpace(fV.mesh(), cele)
194+
cV = fV.reconstruct(element=cele)
195195
cdm = cV.dm
196196

197197
fproblem = fctx._problem
198-
fu = fproblem.u
198+
fu = fproblem.u_restrict
199199
cu = firedrake.Function(cV)
200200

201201
fdeg = PMGBase.max_degree(fV.ufl_element())
@@ -370,8 +370,8 @@ def create_transfer(self, mat_type, cctx, fctx, cbcs, fbcs):
370370
construct_mat = prolongation_matrix_aij
371371
else:
372372
raise ValueError("Unknown matrix type")
373-
cV = cctx.J.arguments()[0].function_space()
374-
fV = fctx.J.arguments()[0].function_space()
373+
cV = cctx._problem.u_restrict.function_space()
374+
fV = fctx._problem.u_restrict.function_space()
375375
cbcs = tuple(cctx._problem.bcs) if cbcs else tuple()
376376
fbcs = tuple(fctx._problem.bcs) if fbcs else tuple()
377377
return cache.setdefault(key, construct_mat(cV, fV, cbcs, fbcs))
@@ -1179,7 +1179,7 @@ def make_permutation_code(V, vshape, pshape, t_in, t_out, array_name):
11791179

11801180
def reference_value_space(V):
11811181
element = finat.ufl.WithMapping(V.ufl_element(), mapping="identity")
1182-
return firedrake.FunctionSpace(V.mesh(), element)
1182+
return V.collapse().reconstruct(element=element)
11831183

11841184

11851185
class StandaloneInterpolationMatrix(object):
@@ -1206,13 +1206,13 @@ def __init__(self, Vc, Vf, Vc_bcs, Vf_bcs):
12061206
self.Vf = reference_value_space(self.Vf)
12071207
self.uc = firedrake.Function(self.Vc, val=self.uc.dat)
12081208
self.uf = firedrake.Function(self.Vf, val=self.uf.dat)
1209-
self.Vc_bcs = [bc.reconstruct(V=self.Vc) for bc in self.Vc_bcs]
1210-
self.Vf_bcs = [bc.reconstruct(V=self.Vf) for bc in self.Vf_bcs]
1209+
self.Vc_bcs = [bc.reconstruct(V=self.Vc, g=0) for bc in self.Vc_bcs]
1210+
self.Vf_bcs = [bc.reconstruct(V=self.Vf, g=0) for bc in self.Vf_bcs]
12111211

12121212
def work_function(self, V):
12131213
if isinstance(V, firedrake.Function):
12141214
return V
1215-
key = (V.ufl_element(), V.mesh())
1215+
key = (V.ufl_element(), V.mesh(), tuple(V.boundary_set))
12161216
try:
12171217
return self._cache_work[key]
12181218
except KeyError:
@@ -1337,17 +1337,14 @@ def make_blas_kernels(self, Vf, Vc):
13371337
restrict = [""]*5
13381338
# get embedding element for Vf with identity mapping and collocated vector component DOFs
13391339
try:
1340-
qelem = felem
1341-
if qelem.mapping() != "identity":
1342-
qelem = qelem.reconstruct(mapping="identity")
1343-
Qf = Vf if qelem == felem else firedrake.FunctionSpace(Vf.mesh(), qelem)
1340+
Qf = Vf if felem.mapping() == "identity" else Vf.reconstruct(mapping="identity")
13441341
mapping_output = make_mapping_code(Qf, cmapping, fmapping, "t0", "t1")
13451342
in_place_mapping = True
13461343
except Exception:
13471344
qelem = finat.ufl.FiniteElement("DQ", cell=felem.cell, degree=PMGBase.max_degree(felem))
13481345
if Vf.value_shape:
13491346
qelem = finat.ufl.TensorElement(qelem, shape=Vf.value_shape, symmetry=felem.symmetry())
1350-
Qf = firedrake.FunctionSpace(Vf.mesh(), qelem)
1347+
Qf = Vf.reconstruct(element=qelem)
13511348
mapping_output = make_mapping_code(Qf, cmapping, fmapping, "t0", "t1")
13521349

13531350
qshape = (Qf.block_size, Qf.finat_element.space_dimension())

tests/firedrake/multigrid/test_p_multigrid.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@ def mat_type(request):
137137
return request.param
138138

139139

140-
def test_p_multigrid_scalar(mesh, mat_type):
140+
@pytest.mark.parametrize("restrict", [False, True], ids=("unrestrict", "restrict"))
141+
def test_p_multigrid_scalar(mesh, mat_type, restrict):
141142
V = FunctionSpace(mesh, "CG", 4)
142143

143144
u = Function(V)
@@ -175,7 +176,7 @@ def test_p_multigrid_scalar(mesh, mat_type):
175176
"pmg_mg_coarse_mg_coarse_ksp_monitor": None,
176177
"pmg_mg_coarse_mg_coarse_pc_type": "gamg",
177178
"pmg_mg_coarse_mg_coarse_pc_gamg_threshold": 0}
178-
problem = NonlinearVariationalProblem(F, u, bcs)
179+
problem = NonlinearVariationalProblem(F, u, bcs, restrict=restrict)
179180
solver = NonlinearVariationalSolver(problem, solver_parameters=sp)
180181
solver.solve()
181182

0 commit comments

Comments
 (0)