-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add preliminary support for the draft 2024.12 version of the standard #82
Conversation
I don't know how to make this depend on API version so for now it's just there always.
This makes it so that it doesn't have a bunch of extra names on it, which it did as a module.
As far as I can tell, NumPy matches the standard specification, except for the fact that NumPy does not set a default value for axis.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copilot reviewed 5 out of 10 changed files in this pull request and generated 1 suggestion.
Files not reviewed (5)
- array_api_strict/_typing.py: Evaluated as low risk
- array_api_strict/init.py: Evaluated as low risk
- array_api_strict/tests/test_array_object.py: Evaluated as low risk
- array_api_strict/tests/test_elementwise_functions.py: Evaluated as low risk
- array_api_strict/_elementwise_functions.py: Evaluated as low risk
@@ -25,3 +26,14 @@ def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array: | |||
if x.device != indices.device: | |||
raise ValueError(f"Arrays from two different devices ({x.device} and {indices.device}) can not be combined.") | |||
return Array._new(np.take(x._array, indices._array, axis=axis), device=x.device) | |||
|
|||
@requires_api_version('2024.12') | |||
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The function take_along_axis
should check that indices.dtype
is in _integer_dtypes
to ensure only integer indices are allowed.
Copilot is powered by AI, so mistakes are possible. Review output carefully before use.
Going to merge and release this. This is all hidden behind a flag and warning, so even though this is still fairly untested, I think it's OK let people start playing with it. I believe I have included every change to the draft standard since 2023.12. However, there are still some changes that are planned for 2024.12, but which aren't merged in the standard yet and which I haven't implemented here yet. The main one of interest is scalar support for array functions like |
No description provided.