Skip to content

Commit

Permalink
Add sample_streams
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Sep 9, 2024
1 parent c9a9b7a commit b4f5f75
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
3 changes: 2 additions & 1 deletion tjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
inverse_softplus, matrix_dot_product, matrix_vector_mul,
outer_product, softplus)
from ._src.partial import Partial
from ._src.rng import RngStream, create_streams, fork_streams
from ._src.rng import RngStream, create_streams, fork_streams, sample_streams
from ._src.shims import custom_jvp, custom_jvp_method, custom_vjp, custom_vjp_method, hessian, jit
from ._src.testing import (assert_tree_allclose, get_relative_test_string, get_test_string,
tree_allclose)
Expand Down Expand Up @@ -116,6 +116,7 @@
'register_graph_as_nnx_node',
'replace_cotangent',
'result_type',
'sample_streams',
'scale_cotangent',
'softplus',
'tree_allclose',
Expand Down
16 changes: 8 additions & 8 deletions tjax/_src/rng.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ def key(self) -> KeyArray:
return key


def fork_streams(rngs: Mapping[str, RngStream],
samples: int | None = None
) -> Mapping[str, KeyArray]:
if samples is None:
return {name: stream.key() for name, stream in rngs.items()}
return {name: split(stream.key(), samples) for name, stream in rngs.items()}


def create_streams(keys: Mapping[str, KeyArray]) -> Mapping[str, RngStream]:
return {name: RngStream(key) for name, key in keys.items()}


def sample_streams(rngs: Mapping[str, RngStream]) -> Mapping[str, KeyArray]:
return {name: stream.key() for name, stream in rngs.items()}


def fork_streams(rngs: Mapping[str, RngStream], samples: int) -> Mapping[str, KeyArray]:
return {name: split(stream.key(), samples) for name, stream in rngs.items()}

0 comments on commit b4f5f75

Please sign in to comment.