Skip to content

Commit

Permalink
Merge pull request #1161 from carlosgmartin:tree_random_like_batch_shape
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 722553687
  • Loading branch information
OptaxDev committed Feb 3, 2025
2 parents c51fbd5 + 3fe2b27 commit b51d9a8
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ Tree
NamedTupleKey
tree_add
tree_add_scalar_mul
tree_batch_shape
tree_cast
tree_div
tree_dtype
Expand Down Expand Up @@ -126,6 +127,10 @@ Tree add and scalar multiply
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_add_scalar_mul

Tree batch reshaping
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_batch_shape

Tree cast
~~~~~~~~~
.. autofunction:: tree_cast
Expand Down
1 change: 1 addition & 0 deletions optax/tree_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from optax.tree_utils._state_utils import tree_set
from optax.tree_utils._tree_math import tree_add
from optax.tree_utils._tree_math import tree_add_scalar_mul
from optax.tree_utils._tree_math import tree_batch_shape
from optax.tree_utils._tree_math import tree_bias_correction
from optax.tree_utils._tree_math import tree_clip
from optax.tree_utils._tree_math import tree_conj
Expand Down
16 changes: 16 additions & 0 deletions optax/tree_utils/_tree_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,22 @@ def tree_linf_norm(tree: Any) -> chex.Numeric:
return tree_max(abs_tree)


def tree_batch_shape(
tree: Any,
shape: tuple[int, ...] = (),
):
"""Add leading batch dimensions to each leaf of a pytree.
Args:
tree: a pytree.
shape: a shape indicating what leading batch dimensions to add.
Returns:
a pytree with the leading batch dimensions added.
"""
return jax.tree.map(lambda x: jnp.broadcast_to(x, shape + x.shape), tree)


def tree_zeros_like(
tree: Any,
dtype: Optional[jax.typing.DTypeLike] = None,
Expand Down

0 comments on commit b51d9a8

Please sign in to comment.