Is there a way to create a diagonal array? #733

Closed
NeilGirdhar opened this issue Jan 21, 2024 · 6 comments

@NeilGirdhar

NeilGirdhar commented Jan 21, 2024

Something like:

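import numpy as np
import numpy.typing as npt

# NumpyRealArray is not defined in the snippet; assumed here to be a real-valued ndarray alias.
NumpyRealArray = npt.NDArray[np.floating]


def create_diagonal(m: NumpyRealArray) -> NumpyRealArray:
    """A vectorized inverse of diagonal: builds (batched) diagonal matrices.

    Args:
        m: Has shape (*k, n)
    Returns: Array with shape (*k, n, n) with the elements of m on the diagonals.
    """
    # np.diag_indices(n) gives two index arrays, both equal to arange(n).
    indices = (..., *np.diag_indices(m.shape[-1]))
    retval = np.zeros((*m.shape, m.shape[-1]), dtype=m.dtype)
    retval[indices] = m
    return retval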

I noticed that the array API standard has no way to do this, in either batched or unbatched mode. Is that right?
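The key step in the batched snippet above is the index tuple `(..., *np.diag_indices(n))`. The small NumPy check below illustrates what it addresses; the sizes and values are arbitrary choices for illustration, not taken from the issue.

```python
import numpy as np

n = 3
rows, cols = np.diag_indices(n)  # two index arrays, both equal to arange(n)

# A batch of two n x n matrices, initially all zeros: shape (2, n, n).
batch = np.zeros((2, n, n))

# (..., rows, cols) keeps the leading batch axis and pairs row i with column i
# on the last two axes, so it addresses every diagonal entry in one step.
# The indexed region has shape (2, n), matching the (*k, n) input of the docstring.
batch[(..., rows, cols)] = np.arange(6.0).reshape(2, n)
```

After the assignment, `batch[0]` has 0, 1, 2 on its diagonal and `batch[1]` has 3, 4, 5, with zeros elsewhere.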

@NeilGirdhar
Author

(Probably niche.)

@rgommers
Member

It is fairly niche, but still an interesting question. It may fall into the category of "should be implementable given the primitives in the standard", and I think it is, given that the standard has __setitem__ and reshape:

>>> def create_diagonal(values):
...     """The 2-D version only, should be generalizable"""
...     n = values.shape[0]
...     x = xp.zeros(n**2, dtype=values.dtype)
...     x[::n+1] = values
...     return xp.reshape(x, (n, n))
...
>>> values = xp.asarray([0, 1, 2, 3, 9.5])
>>> create_diagonal(values)
array([[0. , 0. , 0. , 0. , 0. ],
       [0. , 1. , 0. , 0. , 0. ],
       [0. , 0. , 2. , 0. , 0. ],
       [0. , 0. , 0. , 3. , 0. ],
       [0. , 0. , 0. , 0. , 9.5]])
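The docstring above notes that the 2-D version "should be generalizable". One possible batched generalization, sketched here under the assumption that NumPy stands in for `xp` (the `xp` parameter is a hypothetical addition), flattens only the last two output axes so the same strided assignment works for any number of leading batch dimensions:

```python
import numpy as np


def create_diagonal(values, xp=np):
    """Map shape (*k, n) to (*k, n, n), placing values on the diagonals.

    Uses only zeros, strided __setitem__ and reshape. `xp` is the array
    namespace; NumPy is just a stand-in default for this sketch.
    """
    *batch, n = values.shape
    # Each n x n output matrix is flattened into a row of length n**2.
    x = xp.zeros((*batch, n * n), dtype=values.dtype)
    # In row-major order, entry (i, i) sits at flat index i * (n + 1),
    # so a step of n + 1 along the last axis hits exactly the diagonal.
    x[..., :: n + 1] = values
    return xp.reshape(x, (*batch, n, n))
```

For 1-D input this reduces to the 2-D version above, since `batch` is empty.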

It's vectorized, so this should be fine performance-wise too. To check:

>>> values = xp.asarray([0, 1, 2, 3, 9.5] * 1000)
>>> values.shape
(5000,)
>>> %timeit create_diagonal(values)
4.62 ms ± 69.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
>>> %timeit np.diag(values)
4.61 ms ± 43.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

np.diag is two functions in one (array creation or diagonal extraction), so I think I'm reasonably happy with the create_diagonal above.
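To illustrate the "two functions in one" point, the NumPy snippet below shows `np.diag` switching behaviour on the dimensionality of its input, and that neither mode is batched:

```python
import numpy as np

v = np.asarray([1.0, 2.0, 3.0])

# 1-D input: np.diag creates a 2-D array with v on the diagonal.
m = np.diag(v)

# 2-D input: np.diag extracts the diagonal instead.
d = np.diag(m)

# Inputs with more than two dimensions are rejected with a ValueError,
# so np.diag cannot operate on a stack of matrices.
```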

@rgommers
Member

Also, gh-403 contains the following comment:

On the plus side, it's good to see diag retired in favor of diagonal (especially so in the age of batched matrices).

(that is, linalg.diagonal)
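A batched diagonal extraction over the last two axes is what makes linalg.diagonal the natural counterpart of a batched create_diagonal. The sketch below uses `np.diagonal(..., axis1=-2, axis2=-1)` as a stand-in for the standard's `linalg.diagonal` with the default offset; the input values are arbitrary.

```python
import numpy as np

# Shape (2, 3): two batch entries, each holding a length-3 diagonal.
values = np.asarray([[1.0, 2.0, 3.0],
                     [4.0, 5.0, 6.0]])

# Build a (2, 3, 3) stack of diagonal matrices, one np.diag call per batch entry.
stack = np.stack([np.diag(row) for row in values])

# Extract the diagonal of every matrix in the stack at once (last two axes).
recovered = np.diagonal(stack, axis1=-2, axis2=-1)
```

`recovered` has shape (2, 3) and equals `values`, so extraction undoes creation batch-wise.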

@NeilGirdhar
Author

Thank you very much for taking the time to write this!

I also agree with you about linalg.diagonal and love the new simplicity in the Array API.

@rgommers
Member

In gh-668 there's a very basic start to a guide to answer such questions. This probably fits in there, although it's not a one-liner translation.

The other thing I had in mind is that it'd be nice to have reusable pure Python functions like this in a separate package, à la array-api-compat but with a different scope (APIs not in the standard). There'll be more of these (e.g., here is a portable cov: https://github.com/scipy/scipy/blob/ecda31227b9fcd5a00e12a1c23ede0aeb217c510/scipy/_lib/_array_api.py#L317)
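As a rough illustration of the kind of portable helper meant here, below is a minimal covariance sketch written only against operations the standard defines (mean with keepdims, subtraction, `.T` on 2-D arrays, `@`, division). It is not the SciPy implementation linked above; it assumes a 2-D floating-point input with variables in rows and observations in columns (the `np.cov` default), and the `xp` parameter is a hypothetical way of passing the namespace, with NumPy as a stand-in default.

```python
import numpy as np


def cov(x, xp=np):
    """Unbiased sample covariance of the rows of a 2-D floating-point array.

    Hypothetical portable sketch; x has shape (variables, observations).
    """
    n_obs = x.shape[1]
    # Center each variable (row) on its mean across observations.
    centered = x - xp.mean(x, axis=1, keepdims=True)
    # Divide by N - 1, matching np.cov's default normalization.
    return (centered @ centered.T) / (n_obs - 1)
```

Passing the namespace explicitly keeps the function free of any NumPy-specific calls, which is what would let it live in a separate, library-agnostic package.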

@lucascolley
Member

+1 on a separate package rather than expanding the scope of array-api-compat.

A related point is Robert's suggestion from the scipy linalg discussion:

I think it might still behoove us to have xp.linalg API-compatible versions that use scipy.linalg implementations when given ndarrays and xp.linalg implementations otherwise. Then other parts of scipy that are really concerned about keeping everything xp-native can use that and be sure that we're not going through a "convert to ndarray" codepath. I'm happy with it being scipy-internal, though it might be useful for other projects that already have scipy as a dependency, so maybe data-apis/array-api-compat would be a good place for it, too.

Related in the sense that I think it's beyond the scope of array-api-compat.
