Skip to content

Commit de0a450

Browse files
authored
Cell location on simplicial higher-order meshes (#4484)
* cache bounding box calculation. Recalculate if mesh coords are changed
1 parent 568ccc2 commit de0a450

File tree

4 files changed

+159
-47
lines changed

4 files changed

+159
-47
lines changed

firedrake/interpolation.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,13 @@ def __init__(
465465
V_dest = V.function_space() if isinstance(V, firedrake.Function) else V
466466
src_mesh = extract_unique_domain(expr)
467467
dest_mesh = as_domain(V_dest)
468+
if (
469+
ufl.cell.simplex(src_mesh.topological_dimension()) != src_mesh.ufl_cell()
470+
and numpy.any(numpy.asarray(src_mesh.ufl_coordinate_element().degree()) > 1)
471+
):
472+
raise NotImplementedError(
473+
"Cannot yet interpolate from non-simplicial higher-order meshes into other meshes."
474+
)
468475
src_mesh_gdim = src_mesh.geometric_dimension()
469476
dest_mesh_gdim = dest_mesh.geometric_dimension()
470477
if src_mesh_gdim != dest_mesh_gdim:
@@ -473,15 +480,6 @@ def __init__(
473480
)
474481
self.src_mesh = src_mesh
475482
self.dest_mesh = dest_mesh
476-
if numpy.any(
477-
numpy.asarray(src_mesh.coordinates.function_space().ufl_element().degree())
478-
> 1
479-
):
480-
# Need to implement vertex-only mesh immersion in high order meshes
481-
# for this to work.
482-
raise NotImplementedError(
483-
"Cannot yet interpolate from high order meshes to other meshes."
484-
)
485483

486484
self.sub_interpolators = []
487485

firedrake/mesh.py

Lines changed: 81 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import finat.ufl
77
import FIAT
88
import weakref
9+
from typing import Tuple
910
from collections import OrderedDict, defaultdict
1011
from collections.abc import Sequence
1112
from ufl.classes import ReferenceGrad
@@ -2289,6 +2290,7 @@ def __init__(self, coordinates):
22892290
# submesh
22902291
self.submesh_parent = None
22912292

2293+
self._bounding_box_coords = None
22922294
self._spatial_index = None
22932295
self._saved_coordinate_dat_version = coordinates.dat.dat_version
22942296

@@ -2457,54 +2459,71 @@ def clear_spatial_index(self):
24572459
the coordinate field)."""
24582460
self._spatial_index = None
24592461

2460-
@property
2461-
def spatial_index(self):
2462-
"""Spatial index to quickly find which cell contains a given point.
2462+
@utils.cached_property
2463+
def bounding_box_coords(self) -> Tuple[np.ndarray, np.ndarray] | None:
2464+
"""Calculates bounding boxes for spatial indexing.
24632465
2464-
Notes
2465-
-----
2466+
Returns
2467+
-------
2468+
Tuple of arrays of shape (num_cells, gdim) containing
2469+
the minimum and maximum coordinates of each cell's bounding box.
24662470
2467-
If this mesh has a :attr:`tolerance` property, which
2468-
should be a float, this tolerance is added to the extrama of the
2469-
spatial index so that points just outside the mesh, within tolerance,
2470-
can be found.
2471+
None if the geometric dimension is 1, since libspatialindex
2472+
does not support 1D.
24712473
2474+
Notes
2475+
-----
2476+
If we have a higher-order (bendy) mesh we project the mesh coordinates into
2477+
a Bernstein finite element space. Functions on a Bernstein element are
2478+
Bezier curves and are completely contained in the convex hull of the mesh nodes.
2479+
Hence the bounding box will contain the entire element.
24722480
"""
24732481
from firedrake import function, functionspace
24742482
from firedrake.parloops import par_loop, READ, MIN, MAX
24752483

2476-
if (
2477-
self._spatial_index
2478-
and self.coordinates.dat.dat_version == self._saved_coordinate_dat_version
2479-
):
2480-
return self._spatial_index
2481-
24822484
gdim = self.geometric_dimension()
24832485
if gdim <= 1:
24842486
info_red("libspatialindex does not support 1-dimension, falling back on brute force.")
24852487
return None
24862488

2489+
coord_element = self.ufl_coordinate_element()
2490+
coord_degree = coord_element.degree()
2491+
if ufl.cell.simplex(self.topological_dimension()) != self.ufl_cell():
2492+
# Non-simplex element, e.g. quad or tensor product
2493+
mesh = self
2494+
elif coord_degree == 1:
2495+
mesh = self
2496+
elif coord_element.family() == "Bernstein":
2497+
# Already have Bernstein coordinates, no need to project
2498+
mesh = self
2499+
else:
2500+
# For bendy meshes we project the coordinate function onto Bernstein
2501+
bernstein_fs = functionspace.VectorFunctionSpace(self, "Bernstein", coord_degree)
2502+
f = function.Function(bernstein_fs)
2503+
f.interpolate(self.coordinates)
2504+
mesh = Mesh(f)
2505+
24872506
# Calculate the bounding boxes for all cells by running a kernel
2488-
V = functionspace.VectorFunctionSpace(self, "DG", 0, dim=gdim)
2507+
V = functionspace.VectorFunctionSpace(mesh, "DG", 0, dim=gdim)
24892508
coords_min = function.Function(V, dtype=RealType)
24902509
coords_max = function.Function(V, dtype=RealType)
24912510

24922511
coords_min.dat.data.fill(np.inf)
24932512
coords_max.dat.data.fill(-np.inf)
24942513

24952514
if utils.complex_mode:
2496-
if not np.allclose(self.coordinates.dat.data_ro.imag, 0):
2515+
if not np.allclose(mesh.coordinates.dat.data_ro.imag, 0):
24972516
raise ValueError("Coordinate field has non-zero imaginary part")
2498-
coords = function.Function(self.coordinates.function_space(),
2499-
val=self.coordinates.dat.data_ro_with_halos.real.copy(),
2517+
coords = function.Function(mesh.coordinates.function_space(),
2518+
val=mesh.coordinates.dat.data_ro_with_halos.real.copy(),
25002519
dtype=RealType)
25012520
else:
2502-
coords = self.coordinates
2521+
coords = mesh.coordinates
25032522

2504-
cell_node_list = self.coordinates.function_space().cell_node_list
2523+
cell_node_list = mesh.coordinates.function_space().cell_node_list
25052524
_, nodes_per_cell = cell_node_list.shape
25062525

2507-
domain = "{{[d, i]: 0 <= d < {0} and 0 <= i < {1}}}".format(gdim, nodes_per_cell)
2526+
domain = f"{{[d, i]: 0 <= d < {gdim} and 0 <= i < {nodes_per_cell}}}"
25082527
instructions = """
25092528
for d, i
25102529
f_min[0, d] = fmin(f_min[0, d], f[i, d])
@@ -2518,21 +2537,51 @@ def spatial_index(self):
25182537

25192538
# Reorder bounding boxes according to the cell indices we use
25202539
column_list = V.cell_node_list.reshape(-1)
2521-
coords_min = self._order_data_by_cell_index(column_list, coords_min.dat.data_ro_with_halos)
2522-
coords_max = self._order_data_by_cell_index(column_list, coords_max.dat.data_ro_with_halos)
2540+
coords_min = mesh._order_data_by_cell_index(column_list, coords_min.dat.data_ro_with_halos)
2541+
coords_max = mesh._order_data_by_cell_index(column_list, coords_max.dat.data_ro_with_halos)
2542+
2543+
return coords_min, coords_max
2544+
2545+
@property
2546+
def spatial_index(self):
2547+
"""Builds spatial index from bounding box coordinates, expanding
2548+
the bounding box by the mesh tolerance.
2549+
2550+
Returns
2551+
-------
2552+
:class:`~.spatialindex.SpatialIndex` or None if the mesh is
2553+
one-dimensional.
2554+
2555+
Notes
2556+
-----
2557+
If this mesh has a :attr:`tolerance` property, which
2558+
should be a float, this tolerance is added to the extrema of the
2559+
spatial index so that points just outside the mesh, within tolerance,
2560+
can be found.
25232561
2562+
"""
2563+
if self.coordinates.dat.dat_version != self._saved_coordinate_dat_version:
2564+
if "bounding_box_coords" in self.__dict__:
2565+
del self.bounding_box_coords
2566+
else:
2567+
if self._spatial_index:
2568+
return self._spatial_index
25242569
# Change min and max to refer to an n-hypercube, where n is the
25252570
# geometric dimension of the mesh, centred on the midpoint of the
25262571
# bounding box. Its side length is the L1 diameter of the bounding box.
25272572
# This aids point evaluation on immersed manifolds and other cases
25282573
# where points may be just off the mesh but should be evaluated.
25292574
# TODO: This is perhaps unnecessary when we aren't in these special
25302575
# cases.
2531-
25322576
# We also push max and min out so we can find points on the boundary
25332577
# within the mesh tolerance.
25342578
# NOTE: getattr doesn't work here due to the inheritance games that are
25352579
# going on in getattr.
2580+
if self.bounding_box_coords is None:
2581+
# This happens in 1D meshes
2582+
return None
2583+
else:
2584+
coords_min, coords_max = self.bounding_box_coords
25362585
tolerance = self.tolerance if hasattr(self, "tolerance") else 0.0
25372586
coords_mid = (coords_max + coords_min)/2
25382587
d = np.max(coords_max - coords_min, axis=1)[:, None]
@@ -3360,11 +3409,13 @@ def VertexOnlyMesh(mesh, vertexcoords, reorder=None, missing_points_behaviour='e
33603409
_, pdim = vertexcoords.shape
33613410
if not np.isclose(np.sum(abs(vertexcoords.imag)), 0):
33623411
raise ValueError("Point coordinates must have zero imaginary part")
3363-
# Bendy meshes require a smarter bounding box algorithm at partition and
3364-
# (especially) cell level. Projecting coordinates to Bernstein may be
3365-
# sufficient.
3366-
if np.any(np.asarray(mesh.coordinates.function_space().ufl_element().degree()) > 1):
3367-
raise NotImplementedError("Only straight edged meshes are supported")
3412+
if (
3413+
ufl.cell.simplex(mesh.topological_dimension()) != mesh.ufl_cell()
3414+
and np.any(np.asarray(mesh.ufl_coordinate_element().degree()) > 1)
3415+
):
3416+
raise NotImplementedError(
3417+
"Cannot yet immerse a VertexOnlyMesh in non-simplicial higher-order meshes."
3418+
)
33683419
# Currently we take responsibility for locating the mesh cells in which the
33693420
# vertices lie.
33703421
#

tests/firedrake/regression/test_interpolate_cross_mesh.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,7 @@ def make_high_order(m_low_order, degree):
4242
"unitsquare",
4343
"circlemanifold",
4444
"circlemanifold_to_high_order",
45-
pytest.param(
46-
"unitsquare_from_high_order",
47-
marks=pytest.mark.xfail(
48-
# CalledProcessError is so the parallel tests correctly xfail
49-
raises=(subprocess.CalledProcessError, NotImplementedError),
50-
reason="Cannot yet interpolate from high order meshes to other meshes.",
51-
),
52-
),
45+
"unitsquare_from_high_order",
5346
"unitsquare_to_high_order",
5447
"extrudedcube",
5548
"unitsquare_vfs",

tests/firedrake/regression/test_locate_cell.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
import pytest
22
import numpy as np
33
from firedrake import *
4+
from functools import reduce
5+
from operator import mul
6+
7+
8+
def warp(x, p):
9+
return p * x * (2 * x - 1)
410

511

612
@pytest.fixture(scope="module", params=[False, True])
@@ -76,3 +82,67 @@ def test_locate_cells_ref_coords_and_dists(meshdata):
7682
fcells, ref_coords, l1_dists = m.locate_cells_ref_coords_and_dists(points[:2], cells_ignore=np.array([cells[:4], cells[1:5]]))
7783
assert fcells[0] == -1 or fcells[0] in cells[4:]
7884
assert fcells[1] == -1 or fcells[1] in cells[5:] or fcells[1] in cells[:1]
85+
86+
87+
def test_high_order_location():
88+
mesh = UnitSquareMesh(2, 2)
89+
V = VectorFunctionSpace(mesh, "CG", 3, variant="equispaced")
90+
f = Function(V)
91+
f.interpolate(mesh.coordinates)
92+
93+
warp_indices = np.where((f.dat.data[:, 0] > 0.0) & (f.dat.data[:, 0] < 0.5) & (f.dat.data[:, 1] == 0.0))[0]
94+
f.dat.data[warp_indices, 1] = warp(f.dat.data[warp_indices, 0], 5.0)
95+
mesh = Mesh(f)
96+
97+
# The point (0.25, -0.6) *is* in the mesh, but falls outside the Lagrange bounding box
98+
# The below used to return (None, None), but projecting to Bernstein coordinates
99+
# allows us to locate the cell.
100+
assert mesh.locate_cell([0.25, -0.6], tolerance=0.001) is not None
101+
# The point (0.25, -0.7) is outside the mesh, but inside the Bernstein bounding box.
102+
# This should return (None, None).
103+
assert mesh.locate_cell([0.25, -0.7], tolerance=0.001) is None
104+
105+
# Change mesh coordinates to check that the bounding box is recalculated
106+
mesh.coordinates.dat.data_wo[warp_indices, 1] = warp(mesh.coordinates.dat.data_ro[warp_indices, 0], 8.0)
107+
assert mesh.locate_cell([0.25, -0.6], tolerance=0.0001) is not None
108+
assert mesh.locate_cell([0.25, -0.7], tolerance=0.0001) is not None
109+
assert mesh.locate_cell([0.25, -0.95], tolerance=0.0001) is not None
110+
assert mesh.locate_cell([0.25, -1.05], tolerance=0.0001) is None
111+
112+
113+
def test_high_order_location_warped_interior_facet():
114+
# Here we bend an interior facet and check the right cell is located.
115+
mesh = UnitSquareMesh(2, 2)
116+
V = VectorFunctionSpace(mesh, "CG", 3, variant="equispaced")
117+
f = Function(V)
118+
f.interpolate(mesh.coordinates)
119+
120+
warp_indices = np.where((f.dat.data[:, 0] > 0.0) & (f.dat.data[:, 0] < 0.5) & np.isclose(f.dat.data[:, 1], 0.5))[0]
121+
f.dat.data[warp_indices, 1] += 0.1
122+
mesh = Mesh(f)
123+
124+
assert mesh.locate_cell([0.25, 0.605], tolerance=0.0001) == 1
125+
assert mesh.locate_cell([0.25, 0.62], tolerance=0.0001) == 3
126+
127+
128+
@pytest.mark.parallel([1, 3])
129+
def test_parallel_high_order_location():
130+
mesh = UnitSquareMesh(2, 2)
131+
V = VectorFunctionSpace(mesh, "CG", 3, variant="equispaced")
132+
f = Function(V)
133+
f.interpolate(mesh.coordinates)
134+
135+
warp_indices = np.where((f.dat.data[:, 0] > 0.0) & (f.dat.data[:, 0] < 0.5) & (f.dat.data[:, 1] == 0.0))[0]
136+
f.dat.data[warp_indices, 1] = warp(f.dat.data[warp_indices, 0], 5.0)
137+
138+
mesh = Mesh(f)
139+
V = FunctionSpace(mesh, "CG", 3)
140+
f = Function(V).interpolate(reduce(mul, SpatialCoordinate(mesh)))
141+
142+
vom = VertexOnlyMesh(mesh, [[0.25, -0.6]], tolerance=0.0001, redundant=False)
143+
P0DG = FunctionSpace(vom, "DG", 0)
144+
P0DG_io = FunctionSpace(vom.input_ordering, "DG", 0)
145+
f_at = assemble(interpolate(f, P0DG))
146+
f_at_correct_order = assemble(interpolate(f_at, P0DG_io))
147+
148+
assert np.allclose(f_at_correct_order.dat.data_ro, [-0.6 * 0.25], atol=0.002)

0 commit comments

Comments
 (0)