-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support multiple array formats #13
base: develop
Are you sure you want to change the base?
Conversation
Welcome to Codecov 🎉Once you merge this PR into your default branch, you're all set! Codecov will compare coverage reports and display results in all future pull requests. Thanks for integrating Codecov - We've got you covered ☂️ |
Hi @corentincarton, @oiffrig, any thought on this? |
def sign(x, *args, **kwargs): | ||
"""Reimplement the sign function to handle NaNs. | ||
|
||
The problem is that torch.sign returns 0 for NaNs, but the array API | ||
standard requires NaNs to be propagated. | ||
""" | ||
x = _xp.asarray(x) | ||
r = _xp.sign(x, *args, **kwargs) | ||
r = _xp.asarray(r) | ||
r[_xp.isnan(x)] = _xp.nan | ||
return r |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From what we discussed this seems like a good solution
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great.
I understand what your were saying about the tests now...
# avoid divided by zero warning | ||
err = np.seterr(divide="ignore", invalid="ignore") | ||
try: | ||
err = xp.seterr(divide="ignore", invalid="ignore") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this local to only numpy?
Could it make sense to check if numpy run func otherwise skip?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I created an empty implementation for the other backends so now we can use it just like this:
err = xp.seterr(divide="ignore", invalid="ignore")
But this could be misleading. So, the other solution is:
if array_api_compat.is_numpy_namespace(xp):
err = xp.seterr(divide="ignore", invalid="ignore")
else: | ||
qs = np.asarray(which) | ||
qs = xp.asarray(which) | ||
|
||
if method == "numpy_bulk": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't we change this api and rename the methods? This doesn't make sense anymore now that we have the Array API. @oiffrig what do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure what you mean. Happy to rename the methods to something that doesn't contain "numpy" (although we may want to keep a fallback. PProc uses the "sort" method, it should be kept (more memory-efficient than numpy)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What I mean is: does it still make sense to call it numpy when the backend may not be numpy anymore? It could be confusing when users try to use the function with copy arrays for instance.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, let's give official names that don't contain numpy (may be worth keeping the numpy names for backwards compatibility, don't need to appear in the docs though)
Implements #9
I converted all the modules to be array format agnostic but still allowing numbers as input:
All the methods (with one exception, see below) are working with the following array backends:
All the tests had to be heavily refactored.
Details
The code goes like this:
So the namespace is not passed as a kwarg but determined from the input arrays.
The tests now work with multiple array formats and numbers too:
Dependencies
array-api-compat
is now a required dependency.Outstanding problem with cupy
cupy.quantile()
behaves differently thannp.quantile()
when there are nans in the input array. For this reason the testtest_quantiles_nan()
is disabled for cupy! This problem has to be further investigated!Problems
np.polynomial.polynomial.polyval()
is used but not part of the array API standard. Nopolyval()
is available in torch.torch.sign()
returns 0 for nan. According to the array API standard it should return nan. Numpy behaves correctly.The following methods are used but not available in the array API standard. However, they are all available in numpy, cupy and torch.
np.percentile()
is not available in the array API standard. Torch only hasquantile()
, which is also in numpy/cupy.quantile()
is not in the array API standard.np.histogram2d()
is not available in the array API standard. Torch only hashistogramdd()
, which is also in numpy/cupy.histogramdd()
is not in the array API standard.The following methods are only available in numpy:
The following methods have a different name in the array API standard than in numpy.
fabs()
->abs()
power()
->pow()
On an array-api-compat namespace we can only call
pow()
andabs()
!Solutions
Modifications
++++++++++++++++
power
andfabs
withpow
andabs
, respectively.Patches
+++++++++++++++
The namespace returned by
array_api_compat.array_namespace()
is patched using methods fromutils.compute
. See:utils.array
for details.Note:
utils.compute.percentile
is implemented withquantile
.utils.compute.histogram2d
is implemented withhistogramdd
.utils.compute.seterr
does nothing and returns an empty dictnumpy:
-
np.polynomial.polynomial.polyval
is added aspolyval
cupy:
-
cupy.polynomial.polynomial.polyval
is added aspolyval
-
utils.compute.seterr
is added asseterr
torch:
-
utils.compute.polyval
is added aspolyval
-
utils.compute.percentile
is added aspercentile
-
utils.compute.histogram2d
is added ashistogram2d
-
utils.compute.seterr
is added asseterr
-
sign
is modified to treat nans correctlyother namespaces (not tested, maybe not needed)
-
utils.compute.polyval
is added aspolyval
-
utils.compute.percentile
is added aspercentile
-
utils.compute.histogram2d
is added ashistogram2d
-
utils.compute.seterr
is added asseterr
With this we can call
polyval
,percentile
andhistogram2d
on an array-namespace:TODO
np.seterr