[TRAX] Disable some tf-numpy tests till TF 2.5 releases. Enable full testing on Github Actions!

PiperOrigin-RevId: 369590184
afrozenator authored and copybara-github committed Apr 21, 2021
1 parent 2a251b6 commit 65378ce
Showing 13 changed files with 31 additions and 85 deletions.
11 changes: 8 additions & 3 deletions .github/workflows/build.yaml
@@ -38,7 +38,12 @@ jobs:
     strategy:
       matrix:
         python-version: [3.7]
-        tf-version: ["2.4.*", "tf-nightly"]
+        # tf-nightly has some pip version conflicts, so can't be installed.
+        # Use only numbered TF as of now.
+        # tf-version: ["2.4.*", "tf-nightly"]
+        tf-version: ["2.4.*"]
+        # Which set of tests to run.
+        trax-test: ["lib", "research"]

     # Steps represent a sequence of tasks that will be executed as part of the job
     steps:
@@ -54,7 +59,7 @@ jobs:
         python -m pip install -q -U setuptools numpy
         python -m pip install flake8 pytest
         if [[ ${{matrix.tf-version}} == "tf-nightly" ]]; then python -m pip install tf-nightly; else python -m pip install -q "tensorflow=="${{matrix.tf-version}}; fi
-        pip install -e .[test]
+        pip install -e .[tests]
     # # Lint with flake8
     # - name: Lint with flake8
     #   run: |
@@ -65,7 +70,7 @@
     # Test out right now with only testing one directory.
     - name: Test with pytest
       run: |
-        pytest trax/data
+        TRAX_TEST=" ${{matrix.trax-test}}" ./oss_scripts/oss_tests.sh
     # The below step just reports the success or failure of tests as a "commit status".
     # This is needed for copybara integration.
     - name: Report success or failure as github status
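The matrix above is a cross-product: every tf-version is paired with every trax-test value, so the commented-out tf-nightly entry would double the job count. A quick sketch of the expansion, in Python for illustration (the job-name format is made up, not GitHub's):

import itertools

# Matrix axes as declared in build.yaml above.
python_versions = ["3.7"]
tf_versions = ["2.4.*"]          # ["2.4.*", "tf-nightly"] once nightly installs cleanly
trax_tests = ["lib", "research"]

# GitHub Actions schedules one job per element of the cross-product.
for py, tf, suite in itertools.product(python_versions, tf_versions, trax_tests):
    print(f"job: python={py}, tf={tf}, trax-test={suite}")
# 1 * 1 * 2 = 2 jobs today; re-enabling tf-nightly would make it 4.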
2 changes: 1 addition & 1 deletion .travis.yml
@@ -8,7 +8,7 @@ python:
   - "3.6"
 env:
   global:
-    - TF_VERSION="2.3.*"
+    - TF_VERSION="2.4.*"
   matrix:
     - TRAX_TEST="lib"
     - TRAX_TEST="research"
2 changes: 1 addition & 1 deletion README.md
@@ -5,13 +5,13 @@
 version](https://badge.fury.io/py/trax.svg)](https://badge.fury.io/py/trax)
 [![GitHub
 Issues](https://img.shields.io/github/issues/google/trax.svg)](https://github.com/google/trax/issues)
+![GitHub Build](https://github.com/google/trax/actions/workflows/build.yaml/badge.svg)
 [![Contributions
 welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
 [![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)
 [![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/trax-ml/community)
-[![Travis](https://img.shields.io/travis/google/trax.svg)](https://travis-ci.org/google/trax)


 [Trax](https://trax-ml.readthedocs.io/en/latest/) is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained in the [Google Brain team](https://research.google.com/teams/brain/). This notebook ([run it in colab](https://colab.research.google.com/github/google/trax/blob/master/trax/intro.ipynb)) shows how to use Trax and where you can find more information.

 1. **Run a pre-trained Transformer**: create a translator in a few lines of code
10 changes: 5 additions & 5 deletions oss_scripts/oss_tests.sh
@@ -32,9 +32,7 @@ function set_status() {
 }

 # Check env vars set
-echo "${TF_VERSION:?}" && \
-echo "${TRAX_TEST:?}" && \
-echo "${TRAVIS_PYTHON_VERSION:?}"
+echo "${TRAX_TEST:?}"
 set_status
 if [[ $STATUS -ne 0 ]]
 then
@@ -83,6 +81,7 @@ then
   --deselect=trax/layers/acceleration_test.py::AccelerationTest::test_chunk_grad_memory \
   --deselect=trax/layers/acceleration_test.py::AccelerationTest::test_chunk_memory \
   --ignore=trax/layers/initializers_test.py \
+  --ignore=trax/layers/test_utils.py \
   trax/layers
 set_status

@@ -124,8 +123,9 @@ else
 set_status

 ## Trax2Keras
-pytest trax/trax2keras_test.py
-set_status
+# TODO(afrozm): Make public again after TF 2.5 releases.
+# pytest trax/trax2keras_test.py
+# set_status

 # Check notebooks.

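After this change the script needs only TRAX_TEST ("lib" or "research") to pick a test set, and set_status folds each pytest exit code into an overall status. A rough Python analogue of that control flow, with hypothetical suite paths standing in for the script's full pytest invocations:

import os
import pytest

STATUS = 0

def set_status(code):
    # Analogue of the script's set_status: fold each exit code into STATUS.
    global STATUS
    STATUS |= int(code)

suite = os.environ["TRAX_TEST"].strip()  # "lib" or "research"
if suite == "research":
    set_status(pytest.main(["trax/models"]))  # hypothetical research set
else:
    set_status(pytest.main([
        "--ignore=trax/layers/initializers_test.py",
        "--ignore=trax/layers/test_utils.py",
        "trax/layers",
    ]))
raise SystemExit(STATUS)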
5 changes: 4 additions & 1 deletion trax/layers/combinators_test.py
@@ -519,8 +519,11 @@ def some_layer():
     self.assertEqual(output_shapes, [(3,), (5,), (2,)])


+BACKENDS = [fastmath.Backend.JAX]
+
+
 @parameterized.named_parameters(
-    ('_' + b.value, b) for b in (fastmath.Backend.JAX, fastmath.Backend.TFNP))
+    ('_' + b.value, b) for b in BACKENDS)
 class ScanTest(parameterized.TestCase):

   def _AddWithCarry(self):  # pylint: disable=invalid-name
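The same BACKENDS indirection recurs in the test files below; restoring TFNP coverage later is a one-line change per file. A minimal, self-contained sketch of the pattern, using a stand-in weightless layer rather than the suite's real cases:

import numpy as np
from absl.testing import absltest, parameterized

from trax import fastmath, shapes
import trax.layers as tl

BACKENDS = [fastmath.Backend.JAX]  # Add fastmath.Backend.TFNP back after TF 2.5.


@parameterized.named_parameters(('_' + b.value, b) for b in BACKENDS)
class BackendSmokeTest(parameterized.TestCase):

  def test_relu_shape(self, backend):
    # Each entry in BACKENDS re-runs this test under that backend.
    with fastmath.use_backend(backend):
      layer = tl.Relu()
      x = np.ones((3, 5), dtype=np.float32)
      layer.init(shapes.signature(x))
      y = layer(x)
      self.assertEqual(y.shape, x.shape)


if __name__ == '__main__':
  absltest.main()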
11 changes: 0 additions & 11 deletions trax/layers/research/efficient_attention_test.py
@@ -71,17 +71,6 @@ def test_lsh_self_attention(self):
       y = layer(x)
       self.assertEqual(y.shape, x.shape)

-  def test_lsh_self_attention_tf(self):
-    with fastmath.use_backend(fastmath.Backend.TFNP):
-      layer = efficient_attention.LSHSelfAttention(
-          n_heads=5, d_qk=7, d_v=17, causal=True,
-          chunk_len=8, n_chunks_before=1, n_chunks_after=0,
-          n_hashes=2, n_buckets=4,
-          use_reference_code=True, attention_dropout=0.0, mode='train')
-      x = np.ones((3, 32, 8)).astype(np.float32)
-      _, _ = layer.init(shapes.signature(x))
-      y = layer(x)
-      self.assertEqual(y.shape, x.shape)

   def _run_forward_and_backward(self, model, inp, weights, state):
     def forward(inp, weights):
2 changes: 1 addition & 1 deletion trax/layers/reversible_test.py
@@ -24,7 +24,7 @@
 import trax.layers as tl


-BACKENDS = [fastmath.Backend.JAX, fastmath.Backend.TFNP]
+BACKENDS = [fastmath.Backend.JAX]


 class ReversibleLayerTest(parameterized.TestCase):
5 changes: 4 additions & 1 deletion trax/layers/rnn_test.py
@@ -25,8 +25,11 @@
 import trax.layers as tl


+BACKENDS = [fastmath.Backend.JAX]
+
+
 @parameterized.named_parameters(
-    ('_' + b.value, b) for b in (fastmath.Backend.JAX, fastmath.Backend.TFNP))
+    ('_' + b.value, b) for b in BACKENDS)
 class RnnTest(parameterized.TestCase):

   def test_conv_gru_cell(self, backend):
2 changes: 1 addition & 1 deletion trax/models/reformer/reformer_test.py
@@ -30,7 +30,7 @@
 from trax.models.reformer import reformer


-BACKENDS = [fastmath.Backend.JAX, fastmath.Backend.TFNP]
+BACKENDS = [fastmath.Backend.JAX]


 def short_name(b):
4 changes: 3 additions & 1 deletion trax/models/rnn_test.py
@@ -24,9 +24,11 @@
 from trax import shapes
 from trax.models import rnn

+BACKENDS = [fastmath.Backend.JAX]
+

 @parameterized.named_parameters(
-    ('_' + b.value, b) for b in (fastmath.Backend.JAX, fastmath.Backend.TFNP))
+    ('_' + b.value, b) for b in BACKENDS)
 class RNNTest(parameterized.TestCase):

   def test_rnnlm_forward_shape(self, backend):
13 changes: 0 additions & 13 deletions trax/optimizers/trainer_test.py
@@ -53,19 +53,6 @@ def test_run_simple_task(self):
     rng = fastmath.random.get_prng(0)
     trainer.one_step(labeled_batch, rng)

-  def test_run_simple_task_tfnp(self):
-    """Runs an accelerated optimizer on a simple task, TFNP backend."""
-    with fastmath.use_backend(fastmath.Backend.TFNP):
-      inputs_batch = np.arange(8).reshape((8, 1))  # 8 items per batch
-      targets_batch = np.pi * np.ones_like(inputs_batch)
-      labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
-      loss_layer = tl.Serial(tl.Dense(1), tl.L2Loss())
-      loss_layer.init(labeled_batch)
-      optimizer = optimizers.Adam(.01)
-      optimizer.tree_init(loss_layer.weights)
-      trainer = optimizers.Trainer(loss_layer, optimizer)
-      rng = fastmath.random.get_prng(0)
-      trainer.one_step(labeled_batch, rng)

   def test_run_sharded_reformer2(self):
     """Runs Reformer2 with sharded weights (only on 2+-device systems)."""
40 changes: 3 additions & 37 deletions trax/supervised/trainer_lib_test.py
@@ -89,7 +89,8 @@ def input_stream_masked(n_devices):



-BACKENDS = [fastmath.Backend.JAX, fastmath.Backend.TFNP]
+BACKENDS = [fastmath.Backend.JAX]
+BACKENDS_AND_CONFIGS = [(fastmath.Backend.JAX, [('Simple', None)])]


 def short_name(b):
@@ -229,11 +230,7 @@ def model_fn(mode='train'):
 @parameterized.named_parameters(
     ('_%s_%s_%s' % (short_name(backend), model_name, opt_name(opt)),  # pylint: disable=g-complex-comprehension
      backend, model_name, opt)
-    for backend, configs in [
-        (fastmath.Backend.JAX, [('Simple', None)]),
-        (fastmath.Backend.TFNP, [('Simple', None),
-                                 ('Resnet50', trax_opt.Momentum),
-                                 ('Transformer', trax_opt.Adam)])]
+    for backend, configs in BACKENDS_AND_CONFIGS
     for model_name, opt in configs)
 def test_train_eval_predict(self, backend, model_name, opt):
   self._test_train_eval_predict(backend, model_name, opt)
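Because the comprehension is unchanged, moving the configs into BACKENDS_AND_CONFIGS only shrinks what it yields. A toy expansion with string stand-ins for the backend and optimizer objects shows the resulting test cases:

# String stand-ins for fastmath.Backend.JAX and the optimizer classes.
BACKENDS_AND_CONFIGS = [('jax', [('Simple', None)])]

cases = [('_%s_%s_%s' % (backend, model_name, opt), backend, model_name, opt)
         for backend, configs in BACKENDS_AND_CONFIGS
         for model_name, opt in configs]
print(cases)  # [('_jax_Simple_None', 'jax', 'Simple', None)]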
@@ -533,37 +530,6 @@ def test_tf_xla_forced_compile(self):
     self._test_train_eval_predict('tf')
     fastmath.tf.set_tf_xla_forced_compile(old_flag)

-  def test_no_int32_or_uint32_returned(self):
-    """Tests that Trainer._jit_update_fn doesn't return int32 or uint32.
-
-    TF pins int32/uint32 tensors to CPU, which will cause XLA-forced-compiled
-    computation to copy int32/uint32 outputs to CPU. This test makes sure that
-    won't happen.
-    """
-    with fastmath.use_backend(fastmath.Backend.TFNP):
-      n_classes = 1001
-      model_fn = functools.partial(models.Resnet50,
-                                   n_output_classes=n_classes)
-      inputs = _test_inputs(n_classes, input_shape=(224, 224, 3))
-      trainer = trainer_lib.Trainer(
-          model=model_fn,
-          loss_fn=tl.WeightedCategoryCrossEntropy(),
-          optimizer=trax_opt.SM3,
-          lr_schedule=lr.multifactor(),
-          inputs=inputs,
-      )
-      output_dir = self.create_tempdir().full_path
-      trainer.reset(output_dir)
-      trainer.train_epoch(1, 0)
-      # Those are the things returned by Trainer._jit_update_fn
-      arrays = (trainer._opt_state.weights, trainer._opt_state.slots,
-                trainer._model_state, trainer._rngs)
-      arrays = tf.nest.flatten(arrays)
-      for x in arrays:
-        if isinstance(x, jnp.ndarray) and (x.dtype == jnp.int32 or
-                                           x.dtype == jnp.uint32):
-          raise ValueError('Found an array of int32 or uint32: %s' % x)


 class EpochsTest(absltest.TestCase):
9 changes: 0 additions & 9 deletions trax/supervised/training_test.py
@@ -52,15 +52,6 @@ def test_loop_no_eval_task(self):
     # Loop should initialize and run successfully, even with no eval task.
     training_session.run(n_steps=5)

-  def test_loop_no_eval_task_tfnp(self):
-    """Runs a training loop with no eval task(s), TFNP backend."""
-    with fastmath.use_backend(fastmath.Backend.TFNP):
-      model = tl.Serial(tl.Dense(1))
-      task = training.TrainTask(
-          _very_simple_data(), tl.L2Loss(), optimizers.Adam(.01))
-      training_session = training.Loop(model, [task])
-      # Loop should initialize and run successfully, even with no eval task.
-      training_session.run(n_steps=5)

   def test_loop_checkpoint_low_metric(self):
     """Runs a training loop that saves checkpoints for low metric values."""
