diff --git a/cpp/basix/math.h b/cpp/basix/math.h index 18fe21142..a8e2fdb75 100644 --- a/cpp/basix/math.h +++ b/cpp/basix/math.h @@ -310,9 +310,9 @@ template void dot(const U& A, const V& B, W&& C) { assert(A.extent(1) == B.extent(0)); - assert(C.extent(0) == C.extent(0)); + assert(C.extent(0) == A.extent(0)); assert(C.extent(1) == B.extent(1)); - if (A.extent(0) * B.extent(1) * A.extent(1) < 4096) + if (A.extent(0) * B.extent(1) * A.extent(1) < 512) { std::fill_n(C.data_handle(), C.extent(0) * C.extent(1), 0); for (std::size_t i = 0; i < A.extent(0); ++i) diff --git a/test/test_create.py b/test/test_create.py index 9fdb2e744..f4ab0e712 100644 --- a/test/test_create.py +++ b/test/test_create.py @@ -82,3 +82,8 @@ def test_create_element(cell, degree, family, variant): except RuntimeError as e: if len(e.args) == 0 or "dgesv" in e.args[0]: raise e + + +def test_create_high_degree_lagrange(): + basix.create_element( + basix.ElementFamily.P, basix.CellType.hexahedron, 7, basix.LagrangeVariant.gll_isaac)