Skip to content

Commit

Permalink
FIX test failing after latest jax release
Browse files Browse the repository at this point in the history
Workaround for jax-ml/jax#19713
  • Loading branch information
fabianp committed Feb 8, 2024
1 parent 54c238c commit ba042c5
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 8 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
build/
dist/
venv/
_testing/

# Building the documentation
docs/_autosummary
Expand Down
4 changes: 2 additions & 2 deletions optax/_src/alias_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams(
else:
opt_inject = _inject.inject_hyperparams(opt_factory)(**opt_kwargs)

params = [-jnp.ones((2, 3)), jnp.ones((2, 5, 2))]
grads = [jnp.ones((2, 3)), -jnp.ones((2, 5, 2))]
params = [jnp.negative(jnp.ones((2, 3))), jnp.ones((2, 5, 2))]
grads = [jnp.ones((2, 3)), jnp.negative(jnp.ones((2, 5, 2)))]

state = self.variant(opt.init)(params)
updates, new_state = self.variant(opt.update)(grads, state, params)
Expand Down
8 changes: 4 additions & 4 deletions optax/_src/wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,21 +364,21 @@ def test_multi_steps_skip_not_finite(self):
updates, opt_state = opt_update(dict(a=jnp.ones([])), opt_state, params)
self.assertEqual(int(opt_state.mini_step), 0)
params = update.apply_updates(params, updates)
np.testing.assert_array_equal(params['a'], -jnp.ones([]))
np.testing.assert_array_equal(params['a'], jnp.negative(jnp.ones([])))

with self.subTest('test_inf_updates'):
updates, opt_state = opt_update(
dict(a=jnp.array(float('inf'))), opt_state, params)
self.assertEqual(int(opt_state.mini_step), 0) # No increase in mini_step
params = update.apply_updates(params, updates)
np.testing.assert_array_equal(params['a'], -jnp.ones([]))
np.testing.assert_array_equal(params['a'], jnp.negative(jnp.ones([])))

with self.subTest('test_nan_updates'):
updates, opt_state = opt_update(
dict(a=jnp.full([], float('nan'))), opt_state, params)
self.assertEqual(int(opt_state.mini_step), 0) # No increase in mini_step
params = update.apply_updates(params, updates)
np.testing.assert_array_equal(params['a'], -jnp.ones([]))
np.testing.assert_array_equal(params['a'], jnp.negative(jnp.ones([])))

with self.subTest('test_final_good_updates'):
updates, opt_state = opt_update(dict(a=jnp.ones([])), opt_state, params)
Expand All @@ -387,7 +387,7 @@ def test_multi_steps_skip_not_finite(self):
updates, opt_state = opt_update(dict(a=jnp.ones([])), opt_state, params)
self.assertEqual(int(opt_state.mini_step), 0)
params = update.apply_updates(params, updates)
np.testing.assert_array_equal(params['a'], -jnp.full([], 2.))
np.testing.assert_array_equal(params['a'], jnp.negative(jnp.full([], 2.)))


class MaskedTest(chex.TestCase):
Expand Down
4 changes: 2 additions & 2 deletions optax/contrib/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams(
opt = opt_factory(**opt_kwargs)
opt_inject = _inject.inject_hyperparams(opt_factory)(**opt_kwargs)

params = [-jnp.ones((2, 3)), jnp.ones((2, 5, 2))]
grads = [jnp.ones((2, 3)), -jnp.ones((2, 5, 2))]
params = [jnp.negative(jnp.ones((2, 3))), jnp.ones((2, 5, 2))]
grads = [jnp.ones((2, 3)), jnp.negative(jnp.ones((2, 5, 2)))]

state = self.variant(opt.init)(params)
updates, new_state = self.variant(opt.update)(grads, state, params)
Expand Down
4 changes: 4 additions & 0 deletions test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ set -xeuo pipefail

# Install deps in a virtual env.
rm -rf _testing
rm -rf .pytype
mkdir -p _testing
readonly VENV_DIR="$(mktemp -d -p `pwd`/_testing optax-env.XXXXXXXX)"
# in the unlikely case in which there was something in that directory
Expand Down Expand Up @@ -82,6 +83,9 @@ cd docs && make html
make doctest
cd ..

# cleanup
rm -rf _testing

set +u
deactivate
echo "All tests passed. Congrats!"

0 comments on commit ba042c5

Please sign in to comment.