diff --git a/.github/workflows/array-api-tests-jax.yml b/.github/workflows/array-api-tests-jax.yml
new file mode 100644
index 00000000..4e93c7db
--- /dev/null
+++ b/.github/workflows/array-api-tests-jax.yml
@@ -0,0 +1,10 @@
+name: Array API Tests (JAX)
+
+on: [push, pull_request]
+
+jobs:
+  array-api-tests-jax:
+    uses: ./.github/workflows/array-api-tests.yml
+    with:
+      package-name: jax
+      pytest-extra-args: -k top_k
diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml
index 6e709438..a6361754 100644
--- a/.github/workflows/array-api-tests.yml
+++ b/.github/workflows/array-api-tests.yml
@@ -50,7 +50,8 @@ jobs:
     - name: Checkout array-api-tests
       uses: actions/checkout@v4
       with:
-        repository: data-apis/array-api-tests
+        repository: JuliaPoo/array-api-tests
+        ref: wip-topk-tests
         submodules: 'true'
         path: array-api-tests
     - name: Set up Python ${{ matrix.python-version }}
diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py
index d2aac8b2..9a4b897d 100644
--- a/array_api_compat/dask/array/_aliases.py
+++ b/array_api_compat/dask/array/_aliases.py
@@ -150,6 +150,39 @@ def asarray(
 
     return da.asarray(obj, dtype=dtype, **kwargs)
 
+
+def top_k(
+    x: Array,
+    k: int,
+    /,
+    axis: Optional[int] = None,
+    *,
+    largest: bool = True,
+) -> tuple[Array, Array]:
+    """
+    Return the ``k`` largest (or smallest) values of ``x`` and their indices.
+
+    ``axis=None`` operates on the flattened array. A negative ``k`` is how
+    ``da.topk``/``da.argtopk`` select the smallest elements, hence the
+    sign flip below when ``largest`` is false.
+    """
+    if axis is None:
+        # ``da.topk`` only accepts integer axes, so flatten explicitly.
+        x = x.ravel()
+        axis = -1
+
+    if not largest:
+        k = -k
+
+    # For now, perform the computation twice,
+    # since an equivalent to numpy's `take_along_axis`
+    # does not exist.
+    # See https://github.com/dask/dask/issues/3663.
+    args = da.argtopk(x, k, axis=axis).compute()
+    vals = da.topk(x, k, axis=axis).compute()
+    return vals, args
+
+
 from dask.array import (
     # Element wise aliases
     arccos as acos,
@@ -178,6 +211,7 @@ def asarray(
                     'bitwise_right_shift', 'concat', 'pow',
                     'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8',
                     'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64',
-                    'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type']
+                    'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type',
+                    'top_k']
 
 _all_ignore = ['get_xp', 'da', 'partial', 'common_aliases', 'np']
diff --git a/array_api_compat/jax/__init__.py b/array_api_compat/jax/__init__.py
new file mode 100644
index 00000000..27762936
--- /dev/null
+++ b/array_api_compat/jax/__init__.py
@@ -0,0 +1,63 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import jax
+
+if TYPE_CHECKING:
+    from typing import Optional, Tuple
+
+    from ..common._typing import Array
+
+
+def top_k(
+    x: Array,
+    k: int,
+    /,
+    axis: Optional[int] = None,
+    *,
+    largest: bool = True,
+) -> Tuple[Array, Array]:
+    """
+    Return the ``k`` largest (or smallest) elements of ``x`` along ``axis``.
+
+    Parameters
+    ----------
+    x: Array
+        The source array.
+    k: int
+        The number of elements to return along ``axis``.
+    axis: int, optional
+        Axis to search along. If None, the flattened array is used.
+    largest: bool, optional
+        If True, the largest elements are returned. Otherwise the
+        smallest are returned.
+
+    Returns
+    -------
+    tuple of (values, indices)
+        The selected values and their indices along ``axis``.
+    """
+    if axis is None:
+        x = jax.numpy.ravel(x)
+        axis = -1
+
+    # `jax.lax.top_k` only operates on the last axis, so
+    # `swapaxes` is used to implement the `axis` kwarg
+    x = jax.numpy.swapaxes(x, axis, -1)
+    if largest:
+        vals, args = jax.lax.top_k(x, k)
+    else:
+        # `jax.lax.top_k` has no `largest` switch, and negating the
+        # input would be wrong for unsigned and boolean dtypes, so
+        # take the head of an ascending argsort instead.
+        if not 0 <= k <= x.shape[-1]:
+            raise ValueError(f'k(={k}) must be between 0 and {x.shape[-1]}.')
+        args = jax.numpy.argsort(x, axis=-1)[..., :k]
+        vals = jax.numpy.take_along_axis(x, args, axis=-1)
+    vals = jax.numpy.swapaxes(vals, axis, -1)
+    args = jax.numpy.swapaxes(args, axis, -1)
+    return vals, args
+
+
+__all__ = ['top_k']
diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py
index 70378716..34720ff0 100644
--- a/array_api_compat/numpy/_aliases.py
+++ b/array_api_compat/numpy/_aliases.py
@@ -61,6 +61,111 @@ matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
 tensordot = get_xp(np)(_aliases.tensordot)
 
 
+
+def top_k(a, k, /, *, axis=-1, largest=True):
+    """
+    Returns the ``k`` largest/smallest elements and corresponding
+    indices along the given ``axis``.
+
+    When ``axis`` is None, a flattened array is used.
+
+    If ``largest`` is false, then the ``k`` smallest elements are returned.
+
+    A tuple of ``(values, indices)`` is returned, where ``values`` and
+    ``indices`` are the largest/smallest elements of each row of the
+    input array along the given ``axis`` and their indices.
+
+    Parameters
+    ----------
+    a: array_like
+        The source array
+    k: int
+        The number of largest/smallest elements to return. ``k`` must
+        be a positive integer and within indexable range specified by
+        ``axis``.
+    axis: int, optional
+        Axis along which to find the largest/smallest elements.
+        The default is -1 (the last axis).
+        If None, a flattened array is used.
+    largest: bool, optional
+        If True, largest elements are returned. Otherwise the smallest
+        are returned.
+
+    Returns
+    -------
+    tuple_of_array: tuple
+        The output tuple of ``(topk_values, topk_indices)``, where
+        ``topk_values`` are returned elements from the source array
+        (not necessarily in sorted order), and ``topk_indices`` are
+        the corresponding indices.
+
+    Raises
+    ------
+    ValueError
+        If ``k`` is not positive.
+
+    See Also
+    --------
+    argpartition : Indirect partition.
+    sort : Full sorting.
+
+    Notes
+    -----
+    The returned indices are not guaranteed to be sorted according to
+    the values. Furthermore, the returned indices are not guaranteed
+    to be the earliest/latest occurrence of the element. E.g.,
+    ``top_k([3,3,3], 1)`` can return ``(array([3]), array([1]))``
+    rather than ``(array([3]), array([0]))`` or
+    ``(array([3]), array([2]))``.
+
+    Warning: The treatment of ``np.nan`` in the input array is undefined.
+
+    Examples
+    --------
+    >>> a = np.array([[1,2,3,4,5], [5,4,3,2,1], [3,4,5,1,2]])
+    >>> top_k(a, 2)
+    (array([[4, 5],
+            [4, 5],
+            [4, 5]]),
+     array([[3, 4],
+            [1, 0],
+            [1, 2]]))
+    >>> top_k(a, 2, axis=0)
+    (array([[3, 4, 3, 2, 2],
+            [5, 4, 5, 4, 5]]),
+     array([[2, 1, 1, 1, 2],
+            [1, 2, 2, 0, 0]]))
+    >>> a.flatten()
+    array([1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 3, 4, 5, 1, 2])
+    >>> top_k(a, 2, axis=None)
+    (array([5, 5]), array([ 5, 12]))
+    """
+    if k <= 0:
+        raise ValueError(f'k(={k}) provided must be positive.')
+
+    arr = np.asanyarray(a)
+    if axis is None:
+        # ``axis=None`` means "operate on the flattened input".
+        arr = arr.ravel()
+        axis = 0
+
+    # ``argpartition`` leaves the k selected elements at one end of the
+    # axis: the tail (``-k:``) for the largest, the head (``:k``) otherwise.
+    if largest:
+        kth, window = -k, np.s_[-k:]
+    else:
+        kth, window = k - 1, np.s_[:k]
+    indices_array = np.argpartition(arr, kth, axis=axis)
+
+    # Slice ``window`` out of ``axis`` only. ``argpartition`` has already
+    # validated ``axis``, so normalising a negative axis with ``%`` is safe.
+    leading = (np.s_[:],) * (axis % arr.ndim)
+    topk_indices = indices_array[leading + (window,)]
+    topk_values = np.take_along_axis(arr, topk_indices, axis=axis)
+
+    return (topk_values, topk_indices)
+
+
 def _supports_buffer_protocol(obj):
     try:
         memoryview(obj)
@@ -126,6 +231,6 @@ def asarray(
 __all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
                               'acosh', 'asin', 'asinh', 'atan', 'atan2',
                               'atanh', 'bitwise_left_shift', 'bitwise_invert',
-                              'bitwise_right_shift', 'concat', 'pow']
+                              'bitwise_right_shift', 'concat', 'pow', 'top_k']
 
 _all_ignore = ['np', 'get_xp']
diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index fb53e0ee..603dc15e 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -700,6 +700,27 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
         axis = 0
     return torch.index_select(x, axis, indices, **kwargs)
 
+def top_k(
+    x: array,
+    k: int,
+    /,
+    axis: Optional[int] = None,
+    *,
+    largest: bool = True,
+    **kwargs,
+) -> tuple[array, array]:
+    """
+    Return the ``k`` largest (or smallest) elements of ``x`` and their indices.
+
+    Thin wrapper around ``torch.topk`` that accepts the ``axis`` keyword
+    used by the other backends (``torch.topk`` itself only knows ``dim``).
+    ``axis=None`` operates on the flattened tensor.
+    """
+    if axis is None:
+        x = torch.flatten(x)
+        axis = -1
+    return torch.topk(x, k, dim=axis, largest=largest, **kwargs)
+
 __all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', 'add', 'atan2',
            'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
            'bitwise_right_shift', 'bitwise_xor', 'divide',
@@ -713,6 +734,6 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
            'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
            'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
            'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
-           'take']
+           'take', 'top_k']
 
 _all_ignore = ['torch', 'get_xp']