Skip to content

Commit

Permalink
WIP: Add top_k compatibility
Browse files Browse the repository at this point in the history
This references the PR data-apis/array-api-tests#274.
  • Loading branch information
JuliaPoo committed Jun 26, 2024
1 parent 51daace commit 1f34630
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 4 deletions.
10 changes: 10 additions & 0 deletions .github/workflows/array-api-tests-jax.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
name: Array API Tests (JAX)

on: [push, pull_request]

jobs:
array-api-tests-jax:
uses: ./.github/workflows/array-api-tests.yml
with:
package-name: jax
pytest-extra-args: -k top_k
3 changes: 2 additions & 1 deletion .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ jobs:
- name: Checkout array-api-tests
uses: actions/checkout@v4
with:
repository: data-apis/array-api-tests
repository: JuliaPoo/array-api-tests
ref: wip-topk-tests
submodules: 'true'
path: array-api-tests
- name: Set up Python ${{ matrix.python-version }}
Expand Down
25 changes: 24 additions & 1 deletion array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,28 @@ def asarray(

return da.asarray(obj, dtype=dtype, **kwargs)


def top_k(
x: Array,
k: int,
/,
axis: Optional[int] = None,
*,
largest: bool = True,
) -> tuple[Array, Array]:

if not largest:
k = -k

# For now, perform the computation twice,
# since an equivalent to numpy's `take_along_axis`
# does not exist.
# See https://github.com/dask/dask/issues/3663.
args = da.argtopk(x, k, axis=axis).compute()
vals = da.topk(x, k, axis=axis).compute()
return vals, args


from dask.array import (
# Element wise aliases
arccos as acos,
Expand Down Expand Up @@ -178,6 +200,7 @@ def asarray(
'bitwise_right_shift', 'concat', 'pow',
'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8',
'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64',
'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type']
'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type',
'top_k']

_all_ignore = ['get_xp', 'da', 'partial', 'common_aliases', 'np']
27 changes: 27 additions & 0 deletions array_api_compat/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import jax
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Tuple

from ..common._typing import Array


def top_k(
x: Array,
k: int,
/,
axis: Optional[int] = None,
*,
largest: bool = True,
) -> Tuple[Array, Array]:

# `swapaxes` is used to implement
# the `axis` kwarg
x = jax.numpy.swapaxes(x, axis, -1)
vals, args = jax.lax.top_k(x, k)
vals = jax.numpy.swapaxes(vals, axis, -1)
args = jax.numpy.swapaxes(args, axis, -1)
return vals, args


__all__ = ['top_k']
103 changes: 102 additions & 1 deletion array_api_compat/numpy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,107 @@
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
tensordot = get_xp(np)(_aliases.tensordot)


def top_k(a, k, /, *, axis=-1, largest=True):
"""
Returns the ``k`` largest/smallest elements and corresponding
indices along the given ``axis``.
When ``axis`` is None, a flattened array is used.
If ``largest`` is false, then the ``k`` smallest elements are returned.
A tuple of ``(values, indices)`` is returned, where ``values`` and
``indices`` of the largest/smallest elements of each row of the input
array in the given ``axis``.
Parameters
----------
a: array_like
The source array
k: int
The number of largest/smallest elements to return. ``k`` must
be a positive integer and within indexable range specified by
``axis``.
axis: int, optional
Axis along which to find the largest/smallest elements.
The default is -1 (the last axis).
If None, a flattened array is used.
largest: bool, optional
If True, largest elements are returned. Otherwise the smallest
are returned.
Returns
-------
tuple_of_array: tuple
The output tuple of ``(topk_values, topk_indices)``, where
``topk_values`` are returned elements from the source array
(not necessarily in sorted order), and ``topk_indices`` are
the corresponding indices.
See Also
--------
argpartition : Indirect partition.
sort : Full sorting.
Notes
-----
The returned indices are not guaranteed to be sorted according to
the values. Furthermore, the returned indices are not guaranteed
to be the earliest/latest occurrence of the element. E.g.,
``np.top_k([3,3,3], 1)`` can return ``(array([3]), array([1]))``
rather than ``(array([3]), array([0]))`` or
``(array([3]), array([2]))``.
Warning: The treatment of ``np.nan`` in the input array is undefined.
Examples
--------
>>> a = np.array([[1,2,3,4,5], [5,4,3,2,1], [3,4,5,1,2]])
>>> np.top_k(a, 2)
(array([[4, 5],
[4, 5],
[4, 5]]),
array([[3, 4],
[1, 0],
[1, 2]]))
>>> np.top_k(a, 2, axis=0)
(array([[3, 4, 3, 2, 2],
[5, 4, 5, 4, 5]]),
array([[2, 1, 1, 1, 2],
[1, 2, 2, 0, 0]]))
>>> a.flatten()
array([1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 3, 4, 5, 1, 2])
>>> np.top_k(a, 2, axis=None)
(array([5, 5]), array([ 5, 12]))
"""
if k <= 0:
raise ValueError(f'k(={k}) provided must be positive.')

positive_axis: int
_arr = np.asanyarray(a)
if axis is None:
arr = _arr.ravel()
positive_axis = 0
else:
arr = _arr
positive_axis = axis if axis > 0 else axis % arr.ndim

slice_start = (np.s_[:],) * positive_axis
if largest:
indices_array = np.argpartition(arr, -k, axis=axis)
slice = slice_start + (np.s_[-k:],)
topk_indices = indices_array[slice]
else:
indices_array = np.argpartition(arr, k-1, axis=axis)
slice = slice_start + (np.s_[:k],)
topk_indices = indices_array[slice]

topk_values = np.take_along_axis(arr, topk_indices, axis=axis)

return (topk_values, topk_indices)


def _supports_buffer_protocol(obj):
try:
memoryview(obj)
Expand Down Expand Up @@ -126,6 +227,6 @@ def asarray(
__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow']
'bitwise_right_shift', 'concat', 'pow', 'top_k']

_all_ignore = ['np', 'get_xp']
4 changes: 3 additions & 1 deletion array_api_compat/torch/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,8 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
axis = 0
return torch.index_select(x, axis, indices, **kwargs)

top_k = torch.topk

__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert',
'newaxis', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift',
'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'divide',
Expand All @@ -713,6 +715,6 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
'take']
'take', 'top_k']

_all_ignore = ['torch', 'get_xp']

0 comments on commit 1f34630

Please sign in to comment.