I don't think a stateful solution in JAX relying on side-effects of flattening is something that should be recommended to users. Relying on this kind of implementation detail is likely to fail in corner cases we're not thinking of, and is almost certain to cause issues in the future when that implementation is changed.
-
A fully Array API-compliant functional random sampling library

Perhaps there's even a more radical option, although I haven't looked into it at all and it may be entirely infeasible: port JAX's random implementation to be Array API compliant, so that a stateless random number solution can be applied with any array backend.
-
Using the NumPy RNG itself as the "key"

Idea: Adopt the stateless/functional API from JAX as a Random API and wrap NumPy's RNGs as they are with a compliant set of functions.

In principle, this looks easy enough. Assume we wanted to implement JAX's normal() for NumPy:

    import numpy as np

    def normal(rng, shape=(), dtype=np.float64):
        return rng.standard_normal(shape, dtype)

To use it:

    rng, subrng = random.split(rng)
    val = random.normal(subrng)

The problem lies in "splitting" a traditional RNG. While there is a method to obtain independent sub-generators, we don't really want to invoke that for every random number generation. So, in the simple use case above, we would probably want to return the same RNG twice:

    def split(rng):
        """
        Simple split for standard use case: return the same
        random number generator as a sub-generator.
        """
        return rng, rng

However, there are situations where multiple sub-generators are indeed what is wanted (e.g., parallel execution). So there also needs to be an implementation that performs an actual split:

    def split(rng, num=2):
        """
        Actual split for parallel use case: return independent
        sub-generators.
        """
        return rng.spawn(num)

This is potentially more in line with

The only real solution I see here is to have a separate function for the "sample some numbers" use case. The API could then work along these lines:

    state = ...  # the random state; rng for numpy, key for JAX

    # sample a random variate
    state, sampler = random.get(state)  # get an object that can do random sampling
    val = random.normal(sampler)  # use object to generate some numbers

    # do parallel sampling
    state, substates = random.spawn(state)  # get child state
    sample_in_parallel(substates)

This would need better names for what I call

All in all, I think this is doable. But I am not a huge fan of this API.
-
Implementing a stateful/class-based API using JAX

Idea: Adopt the

Again, this looks easy enough in principle. On the JAX side, all it needs is a wrapper class:

    import jax

    class JRNG:
        def __init__(self, key):
            self.key = key

        def standard_normal(self, size=(), dtype=float) -> jax.Array:
            # advance the stored key and draw from a fresh subkey
            self.key, key = jax.random.split(self.key)
            return jax.random.normal(key, size, dtype)

The major problem here is that we cannot easily pass this stateful object into JAX's compiled functions. There is a workaround: register the class as a pytree node and split the key whenever the object is flattened:

    @jax.tree_util.register_pytree_node_class
    class JRNG:
        ...

        def tree_flatten(self):
            self.key, key = jax.random.split(self.key)
            return (key,), None

        @classmethod
        def tree_unflatten(cls, aux_data, children):
            key, = children
            rng = object.__new__(cls)
            rng.key = key
            return rng

This workaround has the drawback that passing the RNG into a compiled function produces different output than passing it into the same non-compiled function. (For reasons that aren't entirely clear to me, the pytree is also flattened 4 times for each invocation, advancing internal state each time. But this could probably be prevented.)

On the whole, I am not particularly worried about this issue. If the idea is to create a Random API partner for the Array API, there will not generally be compiled functions in the pipeline. And everything that is sufficiently low-level to be a compiled JAX function will not accept the

I packaged this up as a proof-of-concept wrapper for JAX here: glass-dev/jrng. It is only meant to be an illustration of the concept.
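One way to sidestep the pytree trick, sketched here under the assumption that callers are willing to pass a key rather than the RNG object into compiled code: keep the stateful wrapper outside `jax.jit` and hand each compiled function a fresh subkey. The method `next_key` and the function `noisy_sum` are made up for illustration and are not part of glass-dev/jrng.

```python
import jax
import jax.numpy as jnp


class JRNG:
    """Stateful wrapper around a JAX key (illustrative only)."""

    def __init__(self, seed):
        self.key = jax.random.PRNGKey(seed)

    def next_key(self):
        # hypothetical helper: advance the stored key, return a fresh subkey
        self.key, subkey = jax.random.split(self.key)
        return subkey

    def standard_normal(self, size=(), dtype=float):
        return jax.random.normal(self.next_key(), size, dtype)


def noisy_sum(key, x):
    # takes an explicit key, so it is a pure function and safe to compile
    return jnp.sum(x + jax.random.normal(key, x.shape))


noisy_sum_jit = jax.jit(noisy_sum)

rng = JRNG(42)
x = jnp.ones(3)

# the same key is reused here on purpose, only to compare both code paths
key = rng.next_key()
eager = noisy_sum(key, x)
compiled = noisy_sum_jit(key, x)
```

Because the key is an ordinary array, the eager and compiled calls receive identical input and should agree up to floating-point differences, at the cost of exposing keys at the boundary to compiled code.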
-
Building a stateless/functional API on top of NumPy's existing bit generators

NumPy currently implements the

    >>> from numpy_random_api import random
    >>> key = random.key(42)
    >>> key
    PRNGKeyArray([3444837047, 2669555309, 2046530742, 3581440988],
                 dtype=uint32, impl='philox')
    >>> key, subkey = random.split(key)
    >>> key
    PRNGKeyArray([3973757322, 369700608, 604115056, 607984076],
                 dtype=uint32, impl='philox')
    >>> random.normal(subkey, 4)
    array([-0.05883458, 0.6125753 , -1.29899843, 0.12702094])

Equivalent code using NumPy's random interface (where everything is inline so there is no state):

    >>> import numpy as np
    >>> from numpy.random import Generator, Philox, SeedSequence
    >>> key = SeedSequence(42).generate_state(4).view(np.uint32)
    >>> key
    array([3444837047, 2669555309, 2046530742, 3581440988], dtype=uint32)
    >>> key, subkey = Philox(key=key.view(np.uint64)).random_raw(4).reshape(2, 2).view(np.uint32)
    >>> key
    array([3973757322, 369700608, 604115056, 607984076], dtype=uint32)
    >>> Generator(Philox(key=subkey.view(np.uint64))).standard_normal(4)
    array([-0.05883458, 0.6125753 , -1.29899843, 0.12702094])

I have written a simple proof-of-concept here: ntessore/numpy_random_api

I have to say, this works very nicely for building a JAX-like API around NumPy's existing random framework. It's not even entirely clear to me that the approach is limited to the one existing counter-based bit generator; as far as I can tell, there is no fundamental reason why the "key" being passed around cannot be used to seed the traditional bit generators. But I also haven't checked very carefully.

All in all this would probably be my favourite solution for a random API, except that I don't like the user experience for array-agnostic libraries built on top of the Array API (which is my line of business). Since the state is carried around explicitly, this approach requires teaching users to manually
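The inline NumPy calls from the session can be wrapped in three small functions to recover the JAX-like interface. This is only a sketch of the idea, not the code in ntessore/numpy_random_api: the function names mirror the session, keys are plain uint32 arrays rather than a `PRNGKeyArray`, and the `num` parameter of `split` is an assumption borrowed from JAX.

```python
import numpy as np
from numpy.random import Generator, Philox, SeedSequence


def key(seed):
    """Turn an integer seed into a 128-bit key stored as 4 uint32 words."""
    return SeedSequence(seed).generate_state(4)  # uint32 is the default dtype


def split(key, num=2):
    """Derive `num` new keys from the raw Philox output keyed by `key`."""
    raw = Philox(key=key.view(np.uint64)).random_raw(2 * num)
    # two uint64 words per key, reinterpreted as four uint32 words each
    return raw.reshape(num, 2).view(np.uint32)


def normal(key, shape=()):
    """Standard normal variates from a throwaway generator keyed by `key`."""
    return Generator(Philox(key=key.view(np.uint64))).standard_normal(shape)


k = key(42)
k, subkey = split(k)
sample = normal(subkey, 4)
```

No generator outlives a function call, so all state lives in the key arrays and the same key always reproduces the same draw.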
-
I would like to pick up the discussion from #431. The discussion there revolved mostly around the "stateless" functional JAX-like API vs the "stateful" class-based NumPy-like API for random number generation.
A project of ours is adopting the Array API and requires random number generation. I would like to share what I found in my experiments implementing each approach in the other backend. Can we map out the space of possible solutions? This would allow us to see if there is a way forward at this point, or if it's better to defer this further.
Edit: Split into individual comments.