File tree 2 files changed +18
-2
lines changed
2 files changed +18
-2
lines changed Original file line number Diff line number Diff line change @@ -167,8 +167,8 @@ def tree_gram(a):
167
167
168
168
def tree_inf_norm (tree_x ):
169
169
"""Computes the infinity norm of a pytree."""
170
- leaf_inf_norm = tree_map (lambda x : jnp .max ( jnp . abs ( x )) , tree_x )
171
- return tree_reduce (jnp .maximum , leaf_inf_norm )
170
+ leaves_vec = tree_leaves ( tree_map (jnp .ravel , tree_x ) )
171
+ return jnp . max (jnp .abs ( jnp . concatenate ( leaves_vec )) )
172
172
173
173
174
174
def tree_where (cond , a , b ):
Original file line number Diff line number Diff line change @@ -129,6 +129,22 @@ def test_tree_l2_norm(self):
129
129
got = tree_util .tree_l2_norm (self .tree_A )
130
130
self .assertAllClose (expected , got )
131
131
132
+ def test_tree_inf_norm (self ):
133
+ expected = jnp .max (jnp .abs (self .array_A ))
134
+ got = tree_util .tree_inf_norm (self .array_A )
135
+ self .assertAllClose (expected , got )
136
+
137
+ expected = jnp .maximum (jnp .max (jnp .abs (self .tree_A [0 ])),
138
+ jnp .max (jnp .abs (self .tree_A [1 ])))
139
+ got = tree_util .tree_inf_norm (self .tree_A )
140
+ self .assertAllClose (expected , got )
141
+
142
+ tree_a_with_empty_leaves = ({'a' : self .tree_A , 'b' : jnp .array ([])},
143
+ jnp .array ([]))
144
+ expected = tree_util .tree_inf_norm (self .tree_A )
145
+ got = tree_util .tree_inf_norm (tree_a_with_empty_leaves )
146
+ self .assertAllClose (expected , got )
147
+
132
148
def test_tree_zeros_like (self ):
133
149
tree = tree_util .tree_zeros_like (self .tree_A )
134
150
self .assertAllClose (tree_util .tree_l2_norm (tree ), 0.0 )
You can’t perform that action at this time.
0 commit comments