Skip to content

RestrictedFunctionSpace: support Fieldsplit, multigrid, and python PC #4169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
10 changes: 6 additions & 4 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ def reconstruct(self, field=None, V=None, g=None, sub_domain=None, use_split=Fal
V = V.sub(index)
if g is None:
g = self._original_arg
if isinstance(g, firedrake.Function) and g.function_space() != V:
g = firedrake.Function(V).interpolate(g)
if sub_domain is None:
sub_domain = self.sub_domain
if field is not None:
Expand Down Expand Up @@ -739,11 +741,11 @@ def restricted_function_space(V, ids):
return V

assert len(ids) == len(V)
spaces = [Vsub if len(boundary_set) == 0 else
firedrake.RestrictedFunctionSpace(Vsub, boundary_set=boundary_set)
for Vsub, boundary_set in zip(V, ids)]
spaces = [V_ if len(boundary_set) == 0 else
firedrake.RestrictedFunctionSpace(V_, boundary_set=boundary_set, name=V_.name)
for V_, boundary_set in zip(V, ids)]

if len(spaces) == 1:
return spaces[0]
else:
return firedrake.MixedFunctionSpace(spaces)
return firedrake.MixedFunctionSpace(spaces, name=V.name)
10 changes: 6 additions & 4 deletions firedrake/dmhooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,13 @@ def get_function_space(dm):
:raises RuntimeError: if no function space was found.
"""
info = dm.getAttr("__fs_info__")
meshref, element, indices, (name, names) = info
meshref, element, indices, (name, names), boundary_sets = info
mesh = meshref()
if mesh is None:
raise RuntimeError("Somehow your mesh was collected, this should never happen")
V = firedrake.FunctionSpace(mesh, element, name=name)
if any(boundary_sets):
V = firedrake.bcs.restricted_function_space(V, boundary_sets)
if len(V) > 1:
for V_, name in zip(V, names):
V_.topological.name = name
Expand Down Expand Up @@ -93,8 +95,8 @@ def set_function_space(dm, V):
if len(V) > 1:
names = tuple(V_.name for V_ in V)
element = V.ufl_element()

info = (weakref.ref(mesh), element, tuple(reversed(indices)), (V.name, names))
boundary_sets = tuple(V_.boundary_set for V_ in V)
info = (weakref.ref(mesh), element, tuple(reversed(indices)), (V.name, names), boundary_sets)
dm.setAttr("__fs_info__", info)


Expand Down Expand Up @@ -457,7 +459,7 @@ def refine(dm, comm):
if hasattr(V, "_fine"):
fdm = V._fine.dm
else:
V._fine = firedrake.FunctionSpace(hierarchy[level + 1], V.ufl_element())
V._fine = V.reconstruct(mesh=hierarchy[level + 1])
fdm = V._fine.dm
V._fine._coarse = V
return fdm
Expand Down
29 changes: 23 additions & 6 deletions firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,17 +378,23 @@ def make_function_space(cls, mesh, element, name=None):
new = cls.create(new, mesh)
return new

def reconstruct(self, mesh=None, name=None, **kwargs):
def reconstruct(self, mesh=None, element=None, boundary_set=None, name=None, **kwargs):
r"""Reconstruct this :class:`.WithGeometryBase` .

:kwarg mesh: the new :func:`~.Mesh` (defaults to same mesh)
:kwarg element: the new :class:`finat.ufl.FiniteElement` (defaults to same element)
:kwarg boundary_set: boundary subdomain labels defining a new
:func:`~.RestrictedFunctionSpace` (defaults to same boundary_set)
:kwarg name: the new name (defaults to None)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that we can reconstruct named spaces without too much worry, as the new name defaults to None

:returns: the new function space of the same class as ``self``.

Any extra kwargs are used to reconstruct the finite element.
For details see :meth:`finat.ufl.finiteelement.FiniteElement.reconstruct`.
"""
V_parent = self
if boundary_set is None:
boundary_set = V_parent.boundary_set

# Deal with ProxyFunctionSpace
indices = []
while True:
Expand All @@ -403,15 +409,19 @@ def reconstruct(self, mesh=None, name=None, **kwargs):

if mesh is None:
mesh = V_parent.mesh()
if element is None:
element = V_parent.ufl_element()

element = V_parent.ufl_element()
cell = mesh.topology.ufl_cell()
if len(kwargs) > 0 or element.cell != cell:
element = element.reconstruct(cell=cell, **kwargs)

V = type(self).make_function_space(mesh, element, name=name)
for i in reversed(indices):
V = V.sub(i)

if boundary_set:
V = RestrictedFunctionSpace(V, boundary_set=boundary_set, name=V.name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For separate discussion: should RestrictedFunctionSpace be hidden from the user?

return V


Expand Down Expand Up @@ -876,6 +886,14 @@ class RestrictedFunctionSpace(FunctionSpace):
If using this class to solve or similar, a list of DirichletBCs will still
need to be specified on this space and passed into the function.
"""
def __new__(cls, function_space, boundary_set=frozenset(), name=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use RestrictedFunctionSpace from functionspace.py (i.e. the function) instead

mesh = function_space.mesh()
if mesh is not mesh.topology:
V = RestrictedFunctionSpace(function_space.topological,
boundary_set=boundary_set, name=name)
return type(function_space).create(V, mesh)
return FunctionSpace.__new__(cls)

def __init__(self, function_space, boundary_set=frozenset(), name=None):
label = ""
boundary_set_ = []
Expand All @@ -901,8 +919,7 @@ def __init__(self, function_space, boundary_set=frozenset(), name=None):
function_space.ufl_element(),
label=self._label)
self.function_space = function_space
self.name = name or (function_space.name or "Restricted" + "_"
+ "_".join(sorted(map(str, self.boundary_set))))
self.name = name or function_space.name

def set_shared_data(self):
sdata = get_shared_data(self._mesh, self.ufl_element(), self.boundary_set)
Expand Down Expand Up @@ -1200,7 +1217,7 @@ class ProxyFunctionSpace(FunctionSpace):
"""
def __new__(cls, mesh, element, name=None):
topology = mesh.topology
self = super(ProxyFunctionSpace, cls).__new__(cls)
self = FunctionSpace.__new__(cls)
if mesh is not topology:
return WithGeometry.create(self, mesh)
else:
Expand Down Expand Up @@ -1255,7 +1272,7 @@ class ProxyRestrictedFunctionSpace(RestrictedFunctionSpace):
"""
def __new__(cls, function_space, boundary_set=frozenset(), name=None):
topology = function_space._mesh.topology
self = super(ProxyRestrictedFunctionSpace, cls).__new__(cls)
self = FunctionSpace.__new__(cls)
if function_space._mesh is not topology:
return WithGeometry.create(self, function_space._mesh)
else:
Expand Down
15 changes: 9 additions & 6 deletions firedrake/mg/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,17 @@ def cache(self, V):
except KeyError:
return self.caches.setdefault(key, TransferManager.Cache(*key))

def cache_key(self, V):
return (V.dim(), tuple(V.boundary_set))

def V_dof_weights(self, V):
"""Dof weights for averaging projection.

:arg V: function space to compute weights for.
:returns: A PETSc Vec.
"""
cache = self.cache(V)
key = V.dim()
key = self.cache_key(V)
try:
return cache._V_dof_weights[key]
except KeyError:
Expand All @@ -122,7 +125,7 @@ def V_DG_mass(self, V, DG):
:returns: A PETSc Mat mapping from V -> DG
"""
cache = self.cache(V)
key = V.dim()
key = self.cache_key(V)
try:
return cache._V_DG_mass[key]
except KeyError:
Expand Down Expand Up @@ -153,7 +156,7 @@ def V_approx_inv_mass(self, V, DG):
:returns: A PETSc Mat mapping from V -> DG.
"""
cache = self.cache(V)
key = V.dim()
key = self.cache_key(V)
try:
return cache._V_approx_inv_mass[key]
except KeyError:
Expand All @@ -171,7 +174,7 @@ def V_inv_mass_ksp(self, V):
:returns: A PETSc KSP for inverting (V, V).
"""
cache = self.cache(V)
key = V.dim()
key = self.cache_key(V)
try:
return cache._V_inv_mass_ksp[key]
except KeyError:
Expand All @@ -193,7 +196,7 @@ def DG_work(self, V):
"""
needs_dual = ufl.duals.is_dual(V)
cache = self.cache(V)
key = (V.dim(), needs_dual)
key = self.cache_key(V) + (needs_dual,)
try:
return cache._DG_work[key]
except KeyError:
Expand All @@ -210,7 +213,7 @@ def work_vec(self, V):
:returns: A PETSc Vec for V.
"""
cache = self.cache(V)
key = V.dim()
key = self.cache_key(V)
try:
return cache._work_vec[key]
except KeyError:
Expand Down
27 changes: 16 additions & 11 deletions firedrake/mg/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,17 @@ def prolong(coarse, fine):
repeat = (fine_level - coarse_level)*refinements_per_level
next_level = coarse_level * refinements_per_level

element = Vc.ufl_element()
meshes = hierarchy._meshes
for j in range(repeat):
next_level += 1
if j == repeat - 1:
next = fine
Vf = fine.function_space()
else:
Vf = firedrake.FunctionSpace(meshes[next_level], element)
Vf = Vc.reconstruct(mesh=meshes[next_level])
next = firedrake.Function(Vf)

coarse_coords = Vc.mesh().coordinates
coarse_coords = get_coordinates(Vc)
fine_to_coarse = utils.fine_node_to_coarse_node_map(Vf, Vc)
fine_to_coarse_coords = utils.fine_node_to_coarse_node_map(Vf, coarse_coords.function_space())
kernel = kernels.prolong_kernel(coarse)
Expand Down Expand Up @@ -119,7 +118,6 @@ def restrict(fine_dual, coarse_dual):
repeat = (fine_level - coarse_level)*refinements_per_level
next_level = fine_level * refinements_per_level

element = Vc.ufl_element()
meshes = hierarchy._meshes

for j in range(repeat):
Expand All @@ -128,15 +126,15 @@ def restrict(fine_dual, coarse_dual):
coarse_dual.dat.zero()
next = coarse_dual
else:
Vc = firedrake.FunctionSpace(meshes[next_level], element)
next = firedrake.Cofunction(Vc.dual())
Vc = Vf.reconstruct(mesh=meshes[next_level])
next = firedrake.Cofunction(Vc)
Vc = next.function_space()
# XXX: Should be able to figure out locations by pushing forward
# reference cell node locations to physical space.
# x = \sum_i c_i \phi_i(x_hat)
node_locations = utils.physical_node_locations(Vf)
node_locations = utils.physical_node_locations(Vf.dual())

coarse_coords = Vc.mesh().coordinates
coarse_coords = get_coordinates(Vc.dual())
fine_to_coarse = utils.fine_node_to_coarse_node_map(Vf, Vc)
fine_to_coarse_coords = utils.fine_node_to_coarse_node_map(Vf, coarse_coords.function_space())
# Have to do this, because the node set core size is not right for
Expand Down Expand Up @@ -195,7 +193,6 @@ def inject(fine, coarse):
repeat = (fine_level - coarse_level)*refinements_per_level
next_level = fine_level * refinements_per_level

element = Vc.ufl_element()
meshes = hierarchy._meshes

for j in range(repeat):
Expand All @@ -205,12 +202,12 @@ def inject(fine, coarse):
next = coarse
Vc = next.function_space()
else:
Vc = firedrake.FunctionSpace(meshes[next_level], element)
Vc = Vf.reconstruct(mesh=meshes[next_level])
next = firedrake.Function(Vc)
if not dg:
node_locations = utils.physical_node_locations(Vc)

fine_coords = Vf.mesh().coordinates
fine_coords = get_coordinates(Vf)
coarse_node_to_fine_nodes = utils.coarse_node_to_fine_node_map(Vc, Vf)
coarse_node_to_fine_coords = utils.coarse_node_to_fine_node_map(Vc, fine_coords.function_space())

Expand Down Expand Up @@ -242,3 +239,11 @@ def inject(fine, coarse):
fine = next
Vf = Vc
return coarse


def get_coordinates(V):
coords = V.mesh().coordinates
if V.boundary_set:
W = V.reconstruct(element=coords.function_space().ufl_element())
coords = firedrake.Function(W).interpolate(coords)
return coords
Comment on lines +244 to +249
Copy link
Contributor Author

@pbrubeck pbrubeck Mar 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ksagiyam I need to restrict V.mesh().coordinates, is there a better way to do this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need to do that here? Normally, we shouldn't need to do that as restricted function spaces are just regular function spaces with different DoF ordering.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prolong and restrict need the coordinates in the same dof ordering as the restricted spaces.

14 changes: 7 additions & 7 deletions firedrake/mg/ufl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def inject_on_restrict(fine, restriction, rscale, injection, coarse):
if isinstance(g, firedrake.Function) and hasattr(g, "_child"):
manager.inject(g, g._child)

V = problem.u.function_space()
V = problem.u_restrict.function_space()
if not hasattr(V, "_coarse"):
# The hook is persistent and cumulative, but also problem-independent.
# Therefore, we are only adding it once.
Expand All @@ -201,7 +201,7 @@ def inject_on_restrict(fine, restriction, rscale, injection, coarse):
for c in coefficients:
coefficient_mapping[c] = self(c, self, coefficient_mapping=coefficient_mapping)

u = coefficient_mapping[problem.u]
u = coefficient_mapping[problem.u_restrict]

bcs = [self(bc, self) for bc in problem.bcs]
J = self(problem.J, self, coefficient_mapping=coefficient_mapping)
Expand Down Expand Up @@ -277,7 +277,7 @@ def coarsen_snescontext(context, self, coefficient_mapping=None):
if isinstance(val, (firedrake.Function, firedrake.Cofunction)):
V = val.function_space()
coarseneddm = V.dm
parentdm = get_parent(context._problem.u.function_space().dm)
parentdm = get_parent(context._problem.u_restrict.function_space().dm)

# Now attach the hook to the parent DM
if get_appctx(coarseneddm) is None:
Expand Down Expand Up @@ -369,8 +369,8 @@ def create_interpolation(dmc, dmf):

manager = get_transfer_manager(dmf)

V_c = cctx._problem.u.function_space()
V_f = fctx._problem.u.function_space()
V_c = cctx._problem.u_restrict.function_space()
V_f = fctx._problem.u_restrict.function_space()

row_size = V_f.dof_dset.layout_vec.getSizes()
col_size = V_c.dof_dset.layout_vec.getSizes()
Expand All @@ -395,8 +395,8 @@ def create_injection(dmc, dmf):

manager = get_transfer_manager(dmf)

V_c = cctx._problem.u.function_space()
V_f = fctx._problem.u.function_space()
V_c = cctx._problem.u_restrict.function_space()
V_f = fctx._problem.u_restrict.function_space()

row_size = V_f.dof_dset.layout_vec.getSizes()
col_size = V_c.dof_dset.layout_vec.getSizes()
Expand Down
Loading
Loading