
torch take_along_axis does not support negative indices #360

@mdhaber


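The array API standard states that `take_along_axis` must support negative indices in the usual way (for example, `-1` refers to the last element along the axis). The torch wrapper in `array_api_compat` raises an error instead: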

```python
from array_api_compat import torch as xp
# Per the standard, index -1 should select the last (here, the only) element.
xp.take_along_axis(xp.asarray([1]), xp.asarray([-1]))
# RuntimeError: index -1 is out of bounds for dimension 0 with size 1
```
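Until the wrapper handles this, one possible workaround is to map negative indices into the range `[0, size)` before indexing. The sketch below assumes the wrapper ultimately relies on PyTorch's `torch.take_along_dim`, which does not accept negative indices. The helper name `take_along_axis_wrapped` is hypothetical and is not part of `array_api_compat` or PyTorch. It is an illustration, not the library's actual fix.

```python
import torch


def take_along_axis_wrapped(x: torch.Tensor, indices: torch.Tensor, axis: int = -1) -> torch.Tensor:
    """Hypothetical helper: take_along_axis with standard negative-index semantics."""
    size = x.shape[axis]
    # Shift negative indices by the axis length so -1 -> size - 1, -2 -> size - 2, ...
    # Non-negative indices are left unchanged.
    indices = torch.where(indices < 0, indices + size, indices)
    # torch.take_along_dim only accepts indices in [0, size) along `dim`.
    return torch.take_along_dim(x, indices, dim=axis)


print(take_along_axis_wrapped(torch.asarray([1]), torch.asarray([-1])))  # tensor([1])
```

This only normalizes indices in `[-size, 0)`. Anything below `-size` stays negative after the shift, so `torch.take_along_dim` still raises an out-of-bounds error for it, which matches the usual bounds behavior.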
