Skip to content

Commit 7760823

Browse files
emilyfertigJAXopt authors
authored and
JAXopt authors
committed
Internal change
PiperOrigin-RevId: 542883908
1 parent 53f659f commit 7760823

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

jaxopt/_src/tree_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,8 @@ def tree_gram(a):
167167

168168
def tree_inf_norm(tree_x):
169169
"""Computes the infinity norm of a pytree."""
170-
leaf_inf_norm = tree_map(lambda x: jnp.max(jnp.abs(x)), tree_x)
171-
return tree_reduce(jnp.maximum, leaf_inf_norm)
170+
leaves_vec = tree_leaves(tree_map(jnp.ravel, tree_x))
171+
return jnp.max(jnp.abs(jnp.concatenate(leaves_vec)))
172172

173173

174174
def tree_where(cond, a, b):

tests/tree_util_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,22 @@ def test_tree_l2_norm(self):
129129
got = tree_util.tree_l2_norm(self.tree_A)
130130
self.assertAllClose(expected, got)
131131

132+
def test_tree_inf_norm(self):
133+
expected = jnp.max(jnp.abs(self.array_A))
134+
got = tree_util.tree_inf_norm(self.array_A)
135+
self.assertAllClose(expected, got)
136+
137+
expected = jnp.maximum(jnp.max(jnp.abs(self.tree_A[0])),
138+
jnp.max(jnp.abs(self.tree_A[1])))
139+
got = tree_util.tree_inf_norm(self.tree_A)
140+
self.assertAllClose(expected, got)
141+
142+
tree_a_with_empty_leaves = ({'a': self.tree_A, 'b': jnp.array([])},
143+
jnp.array([]))
144+
expected = tree_util.tree_inf_norm(self.tree_A)
145+
got = tree_util.tree_inf_norm(tree_a_with_empty_leaves)
146+
self.assertAllClose(expected, got)
147+
132148
def test_tree_zeros_like(self):
133149
tree = tree_util.tree_zeros_like(self.tree_A)
134150
self.assertAllClose(tree_util.tree_l2_norm(tree), 0.0)

0 commit comments

Comments
 (0)