Skip to content

warn if points defining custom element are outside cell #912

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions python/basix/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# SPDX-License-Identifier: MIT
"""Functions to get cell geometry information and manipulate cell types."""

import numpy as np
import numpy.typing as npt

from basix._basixcpp import CellType
Expand Down Expand Up @@ -72,7 +73,7 @@ def volume(celltype: CellType) -> float:
return _v(celltype)


def facet_jacobians(celltype: CellType) -> npt.ArrayLike:
def facet_jacobians(celltype: CellType) -> npt.NDArray:
"""Jacobians of the facets of a reference cell.

Args:
Expand All @@ -81,10 +82,10 @@ def facet_jacobians(celltype: CellType) -> npt.ArrayLike:
Returns:
Jacobians of the facets.
"""
return _fj(celltype)
return np.array(_fj(celltype))


def edge_jacobians(celltype: CellType) -> npt.ArrayLike:
def edge_jacobians(celltype: CellType) -> npt.NDArray:
"""Jacobians of the edges of a reference cell.

Args:
Expand All @@ -93,10 +94,10 @@ def edge_jacobians(celltype: CellType) -> npt.ArrayLike:
Returns:
Jacobians of the edges.
"""
return _ej(celltype)
return np.array(_ej(celltype))


def facet_normals(celltype: CellType) -> npt.ArrayLike:
def facet_normals(celltype: CellType) -> npt.NDArray:
"""Normals to the facets of a reference cell.

These normals will be oriented using the low-to-high ordering of the
Expand All @@ -108,7 +109,7 @@ def facet_normals(celltype: CellType) -> npt.ArrayLike:
Returns:
Normals to the facets.
"""
return _fn(celltype)
return np.array(_fn(celltype))


def facet_orientations(celltype: CellType) -> list[int]:
Expand All @@ -126,7 +127,7 @@ def facet_orientations(celltype: CellType) -> list[int]:
return _fo(celltype)


def facet_outward_normals(celltype: CellType) -> npt.ArrayLike:
def facet_outward_normals(celltype: CellType) -> npt.NDArray:
"""Normals to the facets of a reference cell.

These normals will be oriented to be pointing outwards.
Expand All @@ -137,10 +138,10 @@ def facet_outward_normals(celltype: CellType) -> npt.ArrayLike:
Returns:
Normals to the facets.
"""
return _fon(celltype)
return np.array(_fon(celltype))


def facet_reference_volumes(celltype: CellType) -> npt.ArrayLike:
def facet_reference_volumes(celltype: CellType) -> npt.NDArray:
"""Reference volumes of the facets of a reference cell.

Args:
Expand All @@ -149,10 +150,10 @@ def facet_reference_volumes(celltype: CellType) -> npt.ArrayLike:
Returns:
Reference volumes.
"""
return _frv(celltype)
return np.array(_frv(celltype))


def geometry(celltype: CellType) -> npt.ArrayLike:
def geometry(celltype: CellType) -> npt.NDArray:
"""Cell geometry.

Args:
Expand All @@ -161,7 +162,7 @@ def geometry(celltype: CellType) -> npt.ArrayLike:
Returns:
Vertices of the cell.
"""
return _geometry(celltype)
return np.array(_geometry(celltype))


def topology(celltype: CellType) -> list[list[list[int]]]:
Expand Down
23 changes: 22 additions & 1 deletion python/basix/finite_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""Functions for creating finite elements."""

import typing
from warnings import warn

import numpy as np
import numpy.typing as npt
Expand All @@ -23,7 +24,7 @@
from basix._basixcpp import create_tp_element as _create_tp_element
from basix._basixcpp import tp_dof_ordering as _tp_dof_ordering
from basix._basixcpp import tp_factors as _tp_factors
from basix.cell import CellType
from basix.cell import CellType, geometry, topology, facet_outward_normals
from basix import MapType
from basix.polynomials import PolysetType
from basix.sobolev_spaces import SobolevSpace
Expand Down Expand Up @@ -628,6 +629,26 @@ def create_custom_element(
wcoeffs = np.dtype(dtype).type(wcoeffs) # type: ignore
x = [[np.dtype(dtype).type(j) for j in i] for i in x] # type: ignore
M = [[np.dtype(dtype).type(j) for j in i] for i in M] # type: ignore

# Check shape of x
tdim = len(topology(cell_type)) - 1
for i in x:
for j in i:
if j.shape[1] != tdim:
raise RuntimeError("x has a point with the wrong tdim")
if len(j.shape) != 2:
raise ValueError("x has the wrong dimension")

# Warn if points are not inside the cell
geo = geometry(cell_type)
top = topology(cell_type)
for points_i in x:
for points_j in points_i:
for p in points_j:
for facet, facet_normal in zip(top[tdim - 1], facet_outward_normals(cell_type)):
if abs(np.dot(p - geo[facet[0]], facet_normal)) > 0.001:
warn(f"Point {p} is not in cell", UserWarning)

if np.issubdtype(dtype, np.float32):
_create_custom_element = _create_custom_element_float32 # type: ignore
elif np.issubdtype(dtype, np.float64):
Expand Down
139 changes: 95 additions & 44 deletions test/test_custom_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,55 +290,60 @@ def test_create_lagrange1_quad():
create_lagrange1_quad()


def assert_failure(**kwargs):
"""Assert that the correct RuntimeError is thrown."""
try:
create_lagrange1_quad(**kwargs)
except RuntimeError as e:
if len(e.args) == 0:
raise e
if "dgesv" in e.args[0]:
raise e
return
with pytest.raises(RuntimeError):
pass


def test_wcoeffs_wrong_shape():
"""Test that a runtime error is thrown when wcoeffs is the wrong shape."""
assert_failure(wcoeffs=np.eye(3))
with pytest.raises(RuntimeError, match="wcoeffs has the wrong number of"):
create_lagrange1_quad(wcoeffs=np.eye(3))


def test_wcoeffs_too_few_cols():
"""Test that a runtime error is thrown when wcoeffs has too few columns."""
assert_failure(
wcoeffs=np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]])
)
with pytest.raises(RuntimeError, match="wcoeffs has the wrong number of"):
create_lagrange1_quad(
wcoeffs=np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]])
)


def test_wcoeffs_too_few_rows():
"""Test that a runtime error is thrown when wcoeffs has too few rows."""
assert_failure(
wcoeffs=np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 1.0]])
)
with pytest.raises(RuntimeError, match="wcoeffs has the wrong number of"):
create_lagrange1_quad(
wcoeffs=np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 1.0]])
)


def test_wcoeffs_zero_row():
"""Test that a runtime error is thrown when wcoeffs has a row of zeros."""
assert_failure(
wcoeffs=np.array(
[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
with pytest.raises(
RuntimeError, match="Cannot orthogonalise the rows of a matrix with incomplete row rank"
):
create_lagrange1_quad(
wcoeffs=np.array(
[
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
]
)
)
)


def test_wcoeffs_equal_rows():
"""Test that a runtime error is thrown when wcoeffs has two equal rows."""
assert_failure(
wcoeffs=np.array(
[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]
with pytest.raises(
RuntimeError, match="Cannot orthogonalise the rows of a matrix with incomplete row rank"
):
create_lagrange1_quad(
wcoeffs=np.array(
[
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 1.0],
[0.0, 0.0, 1.0, 1.0],
]
)
)
)


def test_x_wrong_tdim():
Expand All @@ -355,7 +360,8 @@ def test_x_wrong_tdim():
[z],
[],
]
assert_failure(x=x)
with pytest.raises(RuntimeError, match="x has a point with the wrong tdim"):
create_lagrange1_quad(x=x)


def test_x_too_many_points():
Expand All @@ -372,7 +378,8 @@ def test_x_too_many_points():
[np.array([[0.5, 0.5]])],
[],
]
assert_failure(x=x)
with pytest.raises(RuntimeError, match="M has the wrong shape"):
create_lagrange1_quad(x=x)


def test_x_point_tdim_too_high():
Expand All @@ -384,7 +391,8 @@ def test_x_point_tdim_too_high():
[z],
[np.array([[1.0, 1.0]])],
]
assert_failure(x=x)
with pytest.raises(RuntimeError, match="x has the wrong number of entities"):
create_lagrange1_quad(x=x)


def test_x_wrong_entity_count():
Expand All @@ -401,7 +409,8 @@ def test_x_wrong_entity_count():
[z],
[],
]
assert_failure(x=x)
with pytest.raises(RuntimeError, match="x has the wrong number of entities"):
create_lagrange1_quad(x=x)


def test_M_wrong_value_size():
Expand All @@ -418,7 +427,8 @@ def test_M_wrong_value_size():
[z],
[],
]
assert_failure(M=M)
with pytest.raises(RuntimeError, match="M has the wrong shape"):
create_lagrange1_quad(M=M)


def test_M_too_many_points():
Expand All @@ -435,7 +445,8 @@ def test_M_too_many_points():
[np.array([[[[1.0]]]])],
[],
]
assert_failure(M=M)
with pytest.raises(RuntimeError, match="wcoeffs has the wrong number of rows"):
create_lagrange1_quad(M=M)

z = np.zeros((0, 1, 0, 1))
M = [
Expand All @@ -449,7 +460,8 @@ def test_M_too_many_points():
[z],
[],
]
assert_failure(M=M)
with pytest.raises(RuntimeError, match="M has the wrong shape"):
create_lagrange1_quad(M=M)


def test_M_wrong_entities():
Expand All @@ -461,7 +473,8 @@ def test_M_wrong_entities():
[z],
[],
]
assert_failure(M=M)
with pytest.raises(RuntimeError, match="M has the wrong shape"):
create_lagrange1_quad(M=M)


def test_M_too_many_derivs():
Expand All @@ -478,7 +491,8 @@ def test_M_too_many_derivs():
[z],
[],
]
assert_failure(M=M)
with pytest.raises(RuntimeError, match="M has the wrong shape"):
create_lagrange1_quad(M=M)


def test_M_zero_row():
Expand All @@ -495,29 +509,66 @@ def test_M_zero_row():
[z],
[],
]
assert_failure(M=M)
with pytest.raises(RuntimeError, match="Dual matrix is singular"):
create_lagrange1_quad(M=M)


def test_wrong_value_shape():
"""Test that a runtime error is thrown when value shape is wrong."""
assert_failure(value_shape=[2])
with pytest.raises(RuntimeError, match="wcoeffs has the wrong number of columns"):
create_lagrange1_quad(value_shape=[2])


def test_wrong_cell_type():
"""Test that a runtime error is thrown when cell type is wrong."""
assert_failure(cell_type=CellType.hexahedron)
with pytest.raises(RuntimeError, match="x has a point with the wrong tdim"):
create_lagrange1_quad(cell_type=CellType.hexahedron)


def test_wrong_degree():
"""Test that a runtime error is thrown when degree is wrong."""
assert_failure(degree=0)
with pytest.raises(RuntimeError, match="wcoeffs has the wrong number of columns"):
create_lagrange1_quad(degree=0)


def test_wrong_discontinuous():
"""Test that a runtime error is thrown when discontinuous is wrong."""
assert_failure(discontinuous=True)
with pytest.raises(RuntimeError, match="Discontinuous element can only have interior DOFs"):
create_lagrange1_quad(discontinuous=True)


def test_wrong_interpolation_nderivs():
"""Test that a runtime error is thrown when number of interpolation derivatives is wrong."""
assert_failure(interpolation_nderivs=1)
with pytest.raises(RuntimeError, match="M has the wrong shape"):
create_lagrange1_quad(interpolation_nderivs=1)


@pytest.mark.parametrize("dim", range(4))
@pytest.mark.parametrize("component", range(3))
def test_point_outside_cell_gives_warning(dim, component):
"""Test the defining an element with a point outside the cell gives a warning."""
lagrange = basix.create_element(
basix.ElementFamily.P, CellType.tetrahedron, 4, basix.LagrangeVariant.gll_warped
)
wcoeffs = lagrange.wcoeffs

x = [[j.copy() for j in i] for i in lagrange.x]
M = lagrange.M

x[dim][0][0][component] = -1.0

with pytest.warns(UserWarning):
basix.create_custom_element(
CellType.tetrahedron,
[],
wcoeffs,
x,
M,
0,
basix.MapType.identity,
basix.SobolevSpace.H1,
False,
4,
4,
basix.PolysetType.standard,
)
Loading