From 47fca2f40c716702b64b66bdb10de79e6851011c Mon Sep 17 00:00:00 2001 From: Ralf Gommers Date: Thu, 18 May 2023 20:29:36 +0200 Subject: [PATCH] Add a `linalg.lu` function for the LU decomposition Only the default (partial pivoting) algorithm that is implemented in all libraries and for all devices is added here. gh-627 has details on the no-pivoting case, but it's not universally supported and the only reason to add it would be that it's more performant in some cases where users know it will be numerically stable. Such an addition can be done in the future, but it seems like a potentially large amount of work for implementers for limited gain. Closes gh-627 --- .../extensions/linear_algebra_functions.rst | 1 + src/array_api_stubs/_draft/linalg.py | 46 +++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/spec/draft/extensions/linear_algebra_functions.rst b/spec/draft/extensions/linear_algebra_functions.rst index 6759b2260..314ae0e28 100644 --- a/spec/draft/extensions/linear_algebra_functions.rst +++ b/spec/draft/extensions/linear_algebra_functions.rst @@ -98,6 +98,7 @@ A conforming implementation of this ``linalg`` extension must provide and suppor eigh eigvalsh inv + lu matmul matrix_norm matrix_power diff --git a/src/array_api_stubs/_draft/linalg.py b/src/array_api_stubs/_draft/linalg.py index b03f6eb63..4a422ccfd 100644 --- a/src/array_api_stubs/_draft/linalg.py +++ b/src/array_api_stubs/_draft/linalg.py @@ -263,6 +263,51 @@ def inv(x: array, /) -> array: """ +def lu(x: array, /) -> Tuple[array, array, array]: + """ + Returns the LU decomposition of a matrix (or a stack of matrices). + + The decomposition is: + + .. math:: x = PLU + + where :math:`P` is a permutation matrix, :math:`L` lower triangular with unit + diagonal elements, and :math:`U` upper triangular. + + Parameters + ---------- + x : array + input array having shape ``(..., M, N)`` and whose innermost two + dimensions form ``MxN`` matrices. Should have a floating-point data + type. + + Returns + ------- + out: Tuple[array, array, array] + a namedtuple ``(P, L, U)`` whose + + - first element must have the field name ``P`` and must be an array + of shape ``(M, M)``. + - second element must have the field name ``L`` and must be an array + of shape ``(M, K)``, where ``K == min(M, N)``. + - third element must have the field name ``U`` and must be an array + of shape ``(K, N)``. + + Notes + ----- + A correct decomposition of the prescribed shape must always be returned. + This can be achieved by the implementer using an LU decomposition with + partial pivoting algorithm. + + Note that the LU decomposition is usually not unique, hence different + implementations may return different numerical values for the same input + values. + + .. versionchanged:: 2023.12 + + """ + + def matmul(x1: array, x2: array, /) -> array: """Alias for :func:`~array_api.matmul`.""" @@ -832,6 +877,7 @@ def vector_norm( "eigh", "eigvalsh", "inv", + "lu", "matmul", "matrix_norm", "matrix_power",