diff --git a/spec/draft/extensions/linear_algebra_functions.rst b/spec/draft/extensions/linear_algebra_functions.rst index 6759b2260..314ae0e28 100644 --- a/spec/draft/extensions/linear_algebra_functions.rst +++ b/spec/draft/extensions/linear_algebra_functions.rst @@ -98,6 +98,7 @@ A conforming implementation of this ``linalg`` extension must provide and suppor eigh eigvalsh inv + lu matmul matrix_norm matrix_power diff --git a/src/array_api_stubs/_draft/linalg.py b/src/array_api_stubs/_draft/linalg.py index b03f6eb63..4a422ccfd 100644 --- a/src/array_api_stubs/_draft/linalg.py +++ b/src/array_api_stubs/_draft/linalg.py @@ -263,6 +263,51 @@ def inv(x: array, /) -> array: """ +def lu(x: array, /) -> Tuple[array, array, array]: + """ + Returns the LU decomposition of a matrix (or a stack of matrices). + + The decomposition is: + + .. math:: x = PLU + + where :math:`P` is a permutation matrix, :math:`L` lower triangular with unit + diagonal elements, and :math:`U` upper triangular. + + Parameters + ---------- + x : array + input array having shape ``(..., M, N)`` and whose innermost two + dimensions form ``MxN`` matrices. Should have a floating-point data + type. + + Returns + ------- + out: Tuple[array, array, array] + a namedtuple ``(P, L, U)`` whose + + - first element must have the field name ``P`` and must be an array + of shape ``(M, M)``. + - second element must have the field name ``L`` and must be an array + of shape ``(M, K)``, where ``K == min(M, N)``. + - third element must have the field name ``U`` and must be an array + of shape ``(K, N)``. + + Notes + ----- + A correct decomposition of the prescribed shape must always be returned. + This can be achieved by the implementer using an LU decomposition with + partial pivoting algorithm. + + Note that the LU decomposition is usually not unique, hence different + implementations may return different numerical values for the same input + values. + + .. versionchanged:: 2023.12 + + """ + + def matmul(x1: array, x2: array, /) -> array: """Alias for :func:`~array_api.matmul`.""" @@ -832,6 +877,7 @@ def vector_norm( "eigh", "eigvalsh", "inv", + "lu", "matmul", "matrix_norm", "matrix_power",