Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optax.tree_utils.tree_batch_shape. #1161

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ Tree
NamedTupleKey
tree_add
tree_add_scalar_mul
tree_batch_shape
tree_cast
tree_div
tree_dtype
Expand Down Expand Up @@ -126,6 +127,10 @@ Tree add and scalar multiply
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_add_scalar_mul

Tree batch reshaping
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_batch_shape

Tree cast
~~~~~~~~~
.. autofunction:: tree_cast
Expand Down
1 change: 1 addition & 0 deletions optax/tree_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from optax.tree_utils._state_utils import tree_set
from optax.tree_utils._tree_math import tree_add
from optax.tree_utils._tree_math import tree_add_scalar_mul
from optax.tree_utils._tree_math import tree_batch_shape
from optax.tree_utils._tree_math import tree_bias_correction
from optax.tree_utils._tree_math import tree_clip
from optax.tree_utils._tree_math import tree_conj
Expand Down
16 changes: 16 additions & 0 deletions optax/tree_utils/_tree_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,22 @@ def tree_linf_norm(tree: Any) -> chex.Numeric:
return tree_max(abs_tree)


def tree_batch_shape(
tree: Any,
shape: tuple[int, ...] = (),
):
"""Add leading batch dimensions to each leaf of a pytree.

Args:
tree: a pytree.
shape: a shape indicating what leading batch dimensions to add.

Returns:
a pytree with the leading batch dimensions added.
"""
return jax.tree.map(lambda x: jnp.broadcast_to(x, shape + x.shape), tree)


def tree_zeros_like(
tree: Any,
dtype: Optional[jax.typing.DTypeLike] = None,
Expand Down