diff --git a/docs/api/utilities.rst b/docs/api/utilities.rst index f404113f..da01aa4f 100644 --- a/docs/api/utilities.rst +++ b/docs/api/utilities.rst @@ -92,6 +92,7 @@ Tree NamedTupleKey tree_add tree_add_scalar_mul + tree_batch_shape tree_cast tree_div tree_dtype @@ -126,6 +127,10 @@ Tree add and scalar multiply ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: tree_add_scalar_mul +Tree batch reshaping +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: tree_batch_shape + Tree cast ~~~~~~~~~ .. autofunction:: tree_cast diff --git a/optax/tree_utils/__init__.py b/optax/tree_utils/__init__.py index cc65d6fb..c3685b46 100644 --- a/optax/tree_utils/__init__.py +++ b/optax/tree_utils/__init__.py @@ -28,6 +28,7 @@ from optax.tree_utils._state_utils import tree_set from optax.tree_utils._tree_math import tree_add from optax.tree_utils._tree_math import tree_add_scalar_mul +from optax.tree_utils._tree_math import tree_batch_shape from optax.tree_utils._tree_math import tree_bias_correction from optax.tree_utils._tree_math import tree_clip from optax.tree_utils._tree_math import tree_conj diff --git a/optax/tree_utils/_tree_math.py b/optax/tree_utils/_tree_math.py index dbbef357..7169fc86 100644 --- a/optax/tree_utils/_tree_math.py +++ b/optax/tree_utils/_tree_math.py @@ -255,6 +255,22 @@ def tree_linf_norm(tree: Any) -> chex.Numeric: return tree_max(abs_tree) +def tree_batch_shape( + tree: Any, + shape: tuple[int, ...] = (), +): + """Add leading batch dimensions to each leaf of a pytree. + + Args: + tree: a pytree. + shape: a shape indicating what leading batch dimensions to add. + + Returns: + a pytree with the leading batch dimensions added. + """ + return jax.tree.map(lambda x: jnp.broadcast_to(x, shape + x.shape), tree) + + def tree_zeros_like( tree: Any, dtype: Optional[jax.typing.DTypeLike] = None,