diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 5ca2ccaf2..cb5228636 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -38,7 +38,12 @@ jobs: strategy: matrix: python-version: [3.7] - tf-version: ["2.4.*", "tf-nightly"] + # tf-nightly has some pip version conflicts, so can't be installed. + # Use only numbered TF as of now. + # tf-version: ["2.4.*", "tf-nightly"] + tf-version: ["2.4.*"] + # Which set of tests to run. + trax-test: ["lib", "research"] # Steps represent a sequence of tasks that will be executed as part of the job steps: @@ -54,7 +59,7 @@ jobs: python -m pip install -q -U setuptools numpy python -m pip install flake8 pytest if [[ ${{matrix.tf-version}} == "tf-nightly" ]]; then python -m pip install tf-nightly; else python -m pip install -q "tensorflow=="${{matrix.tf-version}}; fi - pip install -e .[test] + pip install -e .[tests] # # Lint with flake8 # - name: Lint with flake8 # run: | @@ -65,7 +70,7 @@ jobs: # Test out right now with only testing one directory. - name: Test with pytest run: | - pytest trax/data + TRAX_TEST="${{matrix.trax-test}}" ./oss_scripts/oss_tests.sh # The below step just reports the success or failure of tests as a "commit status". # This is needed for copybara integration. 
- name: Report success or failure as github status diff --git a/.travis.yml b/.travis.yml index 59b77f9bf..0251cb069 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,7 +8,7 @@ python: - "3.6" env: global: - - TF_VERSION="2.3.*" + - TF_VERSION="2.4.*" matrix: - TRAX_TEST="lib" - TRAX_TEST="research" diff --git a/README.md b/README.md index 0df562a3a..863103b2c 100644 --- a/README.md +++ b/README.md @@ -5,13 +5,13 @@ version](https://badge.fury.io/py/trax.svg)](https://badge.fury.io/py/trax) [![GitHub Issues](https://img.shields.io/github/issues/google/trax.svg)](https://github.com/google/trax/issues) +![GitHub Build](https://github.com/google/trax/actions/workflows/build.yaml/badge.svg) [![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md) [![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0) [![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/trax-ml/community) [![Travis](https://img.shields.io/travis/google/trax.svg)](https://travis-ci.org/google/trax) - [Trax](https://trax-ml.readthedocs.io/en/latest/) is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained in the [Google Brain team](https://research.google.com/teams/brain/). This notebook ([run it in colab](https://colab.research.google.com/github/google/trax/blob/master/trax/intro.ipynb)) shows how to use Trax and where you can find more information. 1. 
**Run a pre-trained Transformer**: create a translator in a few lines of code diff --git a/oss_scripts/oss_tests.sh b/oss_scripts/oss_tests.sh index 23f149223..cbf11792f 100755 --- a/oss_scripts/oss_tests.sh +++ b/oss_scripts/oss_tests.sh @@ -32,9 +32,7 @@ function set_status() { } # Check env vars set -echo "${TF_VERSION:?}" && \ -echo "${TRAX_TEST:?}" && \ -echo "${TRAVIS_PYTHON_VERSION:?}" +echo "${TRAX_TEST:?}" set_status if [[ $STATUS -ne 0 ]] then @@ -83,6 +81,7 @@ then --deselect=trax/layers/acceleration_test.py::AccelerationTest::test_chunk_grad_memory \ --deselect=trax/layers/acceleration_test.py::AccelerationTest::test_chunk_memory \ --ignore=trax/layers/initializers_test.py \ + --ignore=trax/layers/test_utils.py \ trax/layers set_status @@ -124,8 +123,9 @@ else set_status ## Trax2Keras - pytest trax/trax2keras_test.py - set_status + # TODO(afrozm): Make public again after TF 2.5 releases. + # pytest trax/trax2keras_test.py + # set_status # Check notebooks. diff --git a/trax/layers/combinators_test.py b/trax/layers/combinators_test.py index 72c024390..0b8940319 100644 --- a/trax/layers/combinators_test.py +++ b/trax/layers/combinators_test.py @@ -519,8 +519,11 @@ def some_layer(): self.assertEqual(output_shapes, [(3,), (5,), (2,)]) +BACKENDS = [fastmath.Backend.JAX] + + @parameterized.named_parameters( - ('_' + b.value, b) for b in (fastmath.Backend.JAX, fastmath.Backend.TFNP)) + ('_' + b.value, b) for b in BACKENDS) class ScanTest(parameterized.TestCase): def _AddWithCarry(self): # pylint: disable=invalid-name diff --git a/trax/layers/research/efficient_attention_test.py b/trax/layers/research/efficient_attention_test.py index 6159ec203..7b89d9546 100644 --- a/trax/layers/research/efficient_attention_test.py +++ b/trax/layers/research/efficient_attention_test.py @@ -71,17 +71,6 @@ def test_lsh_self_attention(self): y = layer(x) self.assertEqual(y.shape, x.shape) - def test_lsh_self_attention_tf(self): - with fastmath.use_backend(fastmath.Backend.TFNP): - 
layer = efficient_attention.LSHSelfAttention( - n_heads=5, d_qk=7, d_v=17, causal=True, - chunk_len=8, n_chunks_before=1, n_chunks_after=0, - n_hashes=2, n_buckets=4, - use_reference_code=True, attention_dropout=0.0, mode='train') - x = np.ones((3, 32, 8)).astype(np.float32) - _, _ = layer.init(shapes.signature(x)) - y = layer(x) - self.assertEqual(y.shape, x.shape) def _run_forward_and_backward(self, model, inp, weights, state): def forward(inp, weights): diff --git a/trax/layers/reversible_test.py b/trax/layers/reversible_test.py index b157296f3..8142a37c0 100644 --- a/trax/layers/reversible_test.py +++ b/trax/layers/reversible_test.py @@ -24,7 +24,7 @@ import trax.layers as tl -BACKENDS = [fastmath.Backend.JAX, fastmath.Backend.TFNP] +BACKENDS = [fastmath.Backend.JAX] class ReversibleLayerTest(parameterized.TestCase): diff --git a/trax/layers/rnn_test.py b/trax/layers/rnn_test.py index 1aa5da0b3..dfb53df12 100644 --- a/trax/layers/rnn_test.py +++ b/trax/layers/rnn_test.py @@ -25,8 +25,11 @@ import trax.layers as tl +BACKENDS = [fastmath.Backend.JAX] + + @parameterized.named_parameters( - ('_' + b.value, b) for b in (fastmath.Backend.JAX, fastmath.Backend.TFNP)) + ('_' + b.value, b) for b in BACKENDS) class RnnTest(parameterized.TestCase): def test_conv_gru_cell(self, backend): diff --git a/trax/models/reformer/reformer_test.py b/trax/models/reformer/reformer_test.py index 38d629559..a01ec91ef 100644 --- a/trax/models/reformer/reformer_test.py +++ b/trax/models/reformer/reformer_test.py @@ -30,7 +30,7 @@ from trax.models.reformer import reformer -BACKENDS = [fastmath.Backend.JAX, fastmath.Backend.TFNP] +BACKENDS = [fastmath.Backend.JAX] def short_name(b): diff --git a/trax/models/rnn_test.py b/trax/models/rnn_test.py index 0ba960061..9b4c8006a 100644 --- a/trax/models/rnn_test.py +++ b/trax/models/rnn_test.py @@ -24,9 +24,11 @@ from trax import shapes from trax.models import rnn +BACKENDS = [fastmath.Backend.JAX] + @parameterized.named_parameters( - ('_' + 
b.value, b) for b in (fastmath.Backend.JAX, fastmath.Backend.TFNP)) + ('_' + b.value, b) for b in BACKENDS) class RNNTest(parameterized.TestCase): def test_rnnlm_forward_shape(self, backend): diff --git a/trax/optimizers/trainer_test.py b/trax/optimizers/trainer_test.py index 9b28259e2..863dc00b3 100644 --- a/trax/optimizers/trainer_test.py +++ b/trax/optimizers/trainer_test.py @@ -53,19 +53,6 @@ def test_run_simple_task(self): rng = fastmath.random.get_prng(0) trainer.one_step(labeled_batch, rng) - def test_run_simple_task_tfnp(self): - """Runs an accelerated optimizer on a simple task, TFNP backend.""" - with fastmath.use_backend(fastmath.Backend.TFNP): - inputs_batch = np.arange(8).reshape((8, 1)) # 8 items per batch - targets_batch = np.pi * np.ones_like(inputs_batch) - labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) - loss_layer = tl.Serial(tl.Dense(1), tl.L2Loss()) - loss_layer.init(labeled_batch) - optimizer = optimizers.Adam(.01) - optimizer.tree_init(loss_layer.weights) - trainer = optimizers.Trainer(loss_layer, optimizer) - rng = fastmath.random.get_prng(0) - trainer.one_step(labeled_batch, rng) def test_run_sharded_reformer2(self): """Runs Reformer2 with sharded weights (only on 2+-device systems).""" diff --git a/trax/supervised/trainer_lib_test.py b/trax/supervised/trainer_lib_test.py index f7463705b..934f27e7f 100644 --- a/trax/supervised/trainer_lib_test.py +++ b/trax/supervised/trainer_lib_test.py @@ -89,7 +89,8 @@ def input_stream_masked(n_devices): -BACKENDS = [fastmath.Backend.JAX, fastmath.Backend.TFNP] +BACKENDS = [fastmath.Backend.JAX] +BACKENDS_AND_CONFIGS = [(fastmath.Backend.JAX, [('Simple', None)])] def short_name(b): @@ -229,11 +230,7 @@ def model_fn(mode='train'): @parameterized.named_parameters( ('_%s_%s_%s' % (short_name(backend), model_name, opt_name(opt)), # pylint: disable=g-complex-comprehension backend, model_name, opt) - for backend, configs in [ - (fastmath.Backend.JAX, [('Simple', None)]), - 
(fastmath.Backend.TFNP, [('Simple', None), - ('Resnet50', trax_opt.Momentum), - ('Transformer', trax_opt.Adam)])] + for backend, configs in BACKENDS_AND_CONFIGS for model_name, opt in configs) def test_train_eval_predict(self, backend, model_name, opt): self._test_train_eval_predict(backend, model_name, opt) @@ -533,37 +530,6 @@ def test_tf_xla_forced_compile(self): self._test_train_eval_predict('tf') fastmath.tf.set_tf_xla_forced_compile(old_flag) - def test_no_int32_or_uint32_returned(self): - """Tests that Trainer._jit_update_fn doesn't return int32 or uint32. - - TF pins int32/uint32 tensors to CPU, which will cause XLA-forced-compiled - computation to copy int32/uint32 outputs to CPU. This test makes sure that - won't happen. - """ - with fastmath.use_backend(fastmath.Backend.TFNP): - n_classes = 1001 - model_fn = functools.partial(models.Resnet50, - n_output_classes=n_classes) - inputs = _test_inputs(n_classes, input_shape=(224, 224, 3)) - trainer = trainer_lib.Trainer( - model=model_fn, - loss_fn=tl.WeightedCategoryCrossEntropy(), - optimizer=trax_opt.SM3, - lr_schedule=lr.multifactor(), - inputs=inputs, - ) - output_dir = self.create_tempdir().full_path - trainer.reset(output_dir) - trainer.train_epoch(1, 0) - # Those are the things returned by Trainer._jit_update_fn - arrays = (trainer._opt_state.weights, trainer._opt_state.slots, - trainer._model_state, trainer._rngs) - arrays = tf.nest.flatten(arrays) - for x in arrays: - if isinstance(x, jnp.ndarray) and (x.dtype == jnp.int32 or - x.dtype == jnp.uint32): - raise ValueError('Found an array of int32 or uint32: %s' % x) - class EpochsTest(absltest.TestCase): diff --git a/trax/supervised/training_test.py b/trax/supervised/training_test.py index a87cf5a21..cb4ae0721 100644 --- a/trax/supervised/training_test.py +++ b/trax/supervised/training_test.py @@ -52,15 +52,6 @@ def test_loop_no_eval_task(self): # Loop should initialize and run successfully, even with no eval task. 
training_session.run(n_steps=5) - def test_loop_no_eval_task_tfnp(self): - """Runs a training loop with no eval task(s), TFNP backend.""" - with fastmath.use_backend(fastmath.Backend.TFNP): - model = tl.Serial(tl.Dense(1)) - task = training.TrainTask( - _very_simple_data(), tl.L2Loss(), optimizers.Adam(.01)) - training_session = training.Loop(model, [task]) - # Loop should initialize and run successfully, even with no eval task. - training_session.run(n_steps=5) def test_loop_checkpoint_low_metric(self): """Runs a training loop that saves checkpoints for low metric values."""