Skip to content

Commit 8dab908

Browse files
committed
Add defaults for min_value and max_value to tree_utils.tree_clip.
1 parent c583e75 commit 8dab908

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

optax/tree_utils/_tree_math.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -323,18 +323,20 @@ def tree_full_like(
323323

324324
def tree_clip(
325325
tree: Any,
326-
min_value: Optional[jax.typing.ArrayLike],
327-
max_value: Optional[jax.typing.ArrayLike],
326+
min_value: Optional[jax.typing.ArrayLike] = None,
327+
max_value: Optional[jax.typing.ArrayLike] = None,
328328
) -> Any:
329329
"""Creates an identical tree where all tensors are clipped to `[min, max]`.
330330
331331
Args:
332332
tree: pytree.
333-
min_value: min value to clip all tensors to.
334-
max_value: max value to clip all tensors to.
333+
min_value: optional minimal value to clip all tensors to. If ``None``
334+
(default) then result will not be clipped to any minimum value.
335+
max_value: optional maximal value to clip all tensors to. If ``None``
336+
(default) then result will not be clipped to any maximum value.
335337
336338
Returns:
337-
an tree with the same structure as ``tree``.
339+
a tree with the same structure as ``tree``.
338340
339341
.. versionadded:: 0.2.3
340342
"""

0 commit comments

Comments
 (0)