From 8e2c6eb11c78040c3fcf69eb00bc9568e1acaf9f Mon Sep 17 00:00:00 2001 From: Trax Team Date: Mon, 7 Oct 2019 21:55:40 -0700 Subject: [PATCH] Project import generated by Copybara. FolderOrigin-RevId: /google/src/cloud/afrozm/trax-release-4/google3/.. --- AUTHORS | 7 + CONTRIBUTING.md | 37 + ISSUE_TEMPLATE.md | 30 + LICENSE | 202 +++ README.md | 65 + docs/walkthrough.md | 65 + pylintrc | 221 +++ setup.py | 60 + trax/README.mdE | 65 + trax/__init__.py | 15 + trax/backend.py | 308 ++++ trax/backend_test.py | 75 + trax/configs/mlp_mnist.gin | 49 + trax/configs/mlp_mnist.ginE | 35 + .../position_lookup_transformer_copy.gin | 68 + .../position_lookup_transformer_copy.ginE | 54 + trax/configs/reformer_base_sweep.yaml | 16 + trax/configs/reformer_base_sweep.yamlE | 2 + trax/configs/reformer_enwik8.gin | 125 ++ trax/configs/reformer_enwik8.ginE | 115 ++ trax/configs/reformer_hash_sweep.yaml | 15 + trax/configs/reformer_hash_sweep.yamlE | 1 + trax/configs/reformer_imagenet64.gin | 124 ++ trax/configs/reformer_imagenet64.ginE | 114 ++ trax/configs/reformer_large_sweep.yaml | 17 + trax/configs/reformer_large_sweep.yamlE | 3 + trax/configs/resnet50_imagenet_8gb.gin | 58 + trax/configs/resnet50_imagenet_8gb.ginE | 44 + .../configs/resnet50_imagenet_8gb_testing.gin | 58 + .../resnet50_imagenet_8gb_testing.ginE | 44 + trax/configs/transformer_big_lm1b_8gb.gin | 68 + trax/configs/transformer_big_lm1b_8gb.ginE | 54 + trax/configs/transformer_copy.gin | 91 ++ trax/configs/transformer_copy.ginE | 77 + trax/configs/transformer_imdb_8gb.gin | 68 + trax/configs/transformer_imdb_8gb.ginE | 54 + trax/configs/transformer_lm1b_16gb.gin | 141 ++ trax/configs/transformer_lm1b_16gb.ginE | 127 ++ trax/configs/transformer_lm1b_8gb.gin | 68 + trax/configs/transformer_lm1b_8gb.ginE | 54 + trax/configs/transformer_lm1b_8gb_testing.gin | 68 + .../configs/transformer_lm1b_8gb_testing.ginE | 54 + trax/configs/transformer_lm_wmt_ende_16gb.gin | 148 ++ .../configs/transformer_lm_wmt_ende_16gb.ginE | 134 ++ trax/configs/transformer_lm_wmt_ende_8gb.gin | 78 + trax/configs/transformer_lm_wmt_ende_8gb.ginE | 64 + trax/configs/transformer_ptb_16gb.gin | 142 ++ trax/configs/transformer_ptb_16gb.ginE | 128 ++ ...former_wmt_ende_16gb_adafactor_testing.gin | 73 + ...ormer_wmt_ende_16gb_adafactor_testing.ginE | 59 + .../transformer_wmt_ende_8gb_adafactor.gin | 73 + .../transformer_wmt_ende_8gb_adafactor.ginE | 59 + .../configs/transformer_wmt_ende_8gb_adam.gin | 70 + .../transformer_wmt_ende_8gb_adam.ginE | 56 + trax/configs/transformer_wmt_ende_8gb_sm3.gin | 67 + .../configs/transformer_wmt_ende_8gb_sm3.ginE | 53 + trax/configs/wide_resnet_cifar10_8gb.gin | 95 ++ trax/configs/wide_resnet_cifar10_8gb.ginE | 81 + trax/history.py | 78 + trax/inputs.py | 648 ++++++++ trax/inputs_test.py | 72 + trax/jaxboard.py | 350 +++++ trax/layers/README.md | 60 + trax/layers/__init__.py | 62 + trax/layers/attention.py | 1355 +++++++++++++++++ trax/layers/attention_test.py | 106 ++ trax/layers/base.py | 664 ++++++++ trax/layers/base_test.py | 90 ++ trax/layers/combinators.py | 539 +++++++ trax/layers/combinators_test.py | 116 ++ trax/layers/convolution.py | 126 ++ trax/layers/convolution_test.py | 53 + trax/layers/core.py | 269 ++++ trax/layers/core_test.py | 125 ++ trax/layers/initializers.py | 173 +++ trax/layers/initializers_test.py | 83 + trax/layers/intro.ipynb | 834 ++++++++++ trax/layers/intro.ipynbE | 834 ++++++++++ trax/layers/metrics.py | 124 ++ trax/layers/metrics_test.py | 90 ++ trax/layers/normalization.py | 137 ++ trax/layers/normalization_test.py | 69 + trax/layers/pooling.py | 44 + trax/layers/pooling_test.py | 36 + trax/layers/reversible.py | 126 ++ trax/layers/reversible_test.py | 37 + trax/layers/rnn.py | 128 ++ trax/layers/rnn_test.py | 45 + trax/learning_rate.py | 264 ++++ trax/learning_rate_test.py | 119 ++ trax/models/__init__.py | 51 + trax/models/atari_cnn.py | 79 + trax/models/atari_cnn_test.py | 65 + trax/models/mlp.py | 38 + trax/models/mlp_test.py | 39 + trax/models/neural_gpu.py | 82 + trax/models/neural_gpu_test.py | 39 + trax/models/research/__init__.py | 15 + .../research/position_lookup_transformer.py | 341 +++++ trax/models/research/reformer.py | 531 +++++++ trax/models/research/reformer_test.py | 104 ++ trax/models/resnet.py | 169 ++ trax/models/resnet_test.py | 45 + trax/models/transformer.py | 395 +++++ trax/models/transformer_test.py | 104 ++ trax/notebooks/trax_demo_iclr2019.ipynb | 854 +++++++++++ trax/optimizers/__init__.py | 37 + trax/optimizers/base.py | 465 ++++++ trax/rl/__init__.py | 52 + trax/rl/base_trainer.py | 141 ++ trax/rl/base_trainer_test.py | 144 ++ trax/rl/configs/acrobot.gin | 44 + trax/rl/configs/acrobot.ginE | 30 + trax/rl/configs/acrobot_transformer.gin | 48 + trax/rl/configs/acrobot_transformer.ginE | 34 + trax/rl/configs/atari.gin | 44 + trax/rl/configs/atari.ginE | 30 + trax/rl/configs/atari_regression_test.gin | 44 + trax/rl/configs/atari_regression_test.ginE | 30 + ...nline_tune_transformer_imagenet64_16gb.gin | 119 ++ ...line_tune_transformer_imagenet64_16gb.ginE | 105 ++ .../env_online_tune_transformer_lm1b_16gb.gin | 113 ++ ...env_online_tune_transformer_lm1b_16gb.ginE | 99 ++ ...line_tune_transformer_lm_wmt_ende_16gb.gin | 110 ++ ...ine_tune_transformer_lm_wmt_ende_16gb.ginE | 96 ++ .../env_online_tune_transformer_ptb_16gb.gin | 108 ++ .../env_online_tune_transformer_ptb_16gb.ginE | 94 ++ ...nv_online_tune_wide_resnet_cifar10_8gb.gin | 66 + ...v_online_tune_wide_resnet_cifar10_8gb.ginE | 52 + trax/rl/configs/ppo_online_tune.gin | 51 + trax/rl/configs/ppo_online_tune.ginE | 37 + .../ppo_online_tune_wide_resnet_cifar10.gin | 93 ++ .../ppo_online_tune_wide_resnet_cifar10.ginE | 79 + trax/rl/configs/simple_online_tune.gin | 109 ++ trax/rl/configs/simple_online_tune.ginE | 95 ++ .../configs/simple_online_tune_serialized.gin | 114 ++ .../simple_online_tune_serialized.ginE | 100 ++ trax/rl/envs/__init__.py | 34 + trax/rl/envs/async_trajectory_collector.py | 197 +++ .../rl/envs/async_trajectory_collector_lib.py | 195 +++ .../async_trajectory_collector_lib_test.py | 64 + trax/rl/envs/fake_env.py | 68 + trax/rl/envs/fake_env_test.py | 64 + trax/rl/envs/online_tune.py | 57 + trax/rl/envs/online_tune_env.py | 232 +++ trax/rl/envs/online_tune_env_test.py | 151 ++ trax/rl/envs/online_tune_test.py | 110 ++ trax/rl/online_tune.py | 115 ++ trax/rl/online_tune_test.py | 175 +++ trax/rl/ppo.py | 971 ++++++++++++ trax/rl/ppo_test.py | 643 ++++++++ trax/rl/ppo_trainer.py | 844 ++++++++++ trax/rl/ppo_trainer_test.py | 306 ++++ trax/rl/serialization_utils.py | 184 +++ trax/rl/serialization_utils_test.py | 168 ++ trax/rl/simple.py | 236 +++ trax/rl/simple_test.py | 304 ++++ trax/rl/simple_trainer.py | 341 +++++ trax/rl/simple_trainer_test.py | 96 ++ trax/rl/simulated_env_problem.py | 499 ++++++ trax/rl/simulated_env_problem_test.py | 292 ++++ trax/rl/space_serializer.py | 216 +++ trax/rl/space_serializer_test.py | 161 ++ trax/rl/trainers.py | 37 + trax/rl_trainer.py | 209 +++ trax/trainer.py | 135 ++ trax/trainer_lib.py | 904 +++++++++++ trax/trainer_lib_test.py | 267 ++++ trax/utils.py | 43 + 169 files changed, 26304 insertions(+) create mode 100644 AUTHORS create mode 100644 CONTRIBUTING.md create mode 100644 ISSUE_TEMPLATE.md create mode 100644 LICENSE create mode 100644 README.md create mode 100644 docs/walkthrough.md create mode 100644 pylintrc create mode 100644 setup.py create mode 100644 trax/README.mdE create mode 100644 trax/__init__.py create mode 100644 trax/backend.py create mode 100644 trax/backend_test.py create mode 100644 trax/configs/mlp_mnist.gin create mode 100644 trax/configs/mlp_mnist.ginE create mode 100644 trax/configs/position_lookup_transformer_copy.gin create mode 100644 trax/configs/position_lookup_transformer_copy.ginE create mode 100644 trax/configs/reformer_base_sweep.yaml create mode 100644 trax/configs/reformer_base_sweep.yamlE create mode 100644 trax/configs/reformer_enwik8.gin create mode 100644 trax/configs/reformer_enwik8.ginE create mode 100644 trax/configs/reformer_hash_sweep.yaml create mode 100644 trax/configs/reformer_hash_sweep.yamlE create mode 100644 trax/configs/reformer_imagenet64.gin create mode 100644 trax/configs/reformer_imagenet64.ginE create mode 100644 trax/configs/reformer_large_sweep.yaml create mode 100644 trax/configs/reformer_large_sweep.yamlE create mode 100644 trax/configs/resnet50_imagenet_8gb.gin create mode 100644 trax/configs/resnet50_imagenet_8gb.ginE create mode 100644 trax/configs/resnet50_imagenet_8gb_testing.gin create mode 100644 trax/configs/resnet50_imagenet_8gb_testing.ginE create mode 100644 trax/configs/transformer_big_lm1b_8gb.gin create mode 100644 trax/configs/transformer_big_lm1b_8gb.ginE create mode 100644 trax/configs/transformer_copy.gin create mode 100644 trax/configs/transformer_copy.ginE create mode 100644 trax/configs/transformer_imdb_8gb.gin create mode 100644 trax/configs/transformer_imdb_8gb.ginE create mode 100644 trax/configs/transformer_lm1b_16gb.gin create mode 100644 trax/configs/transformer_lm1b_16gb.ginE create mode 100644 trax/configs/transformer_lm1b_8gb.gin create mode 100644 trax/configs/transformer_lm1b_8gb.ginE create mode 100644 trax/configs/transformer_lm1b_8gb_testing.gin create mode 100644 trax/configs/transformer_lm1b_8gb_testing.ginE create mode 100644 trax/configs/transformer_lm_wmt_ende_16gb.gin create mode 100644 trax/configs/transformer_lm_wmt_ende_16gb.ginE create mode 100644 trax/configs/transformer_lm_wmt_ende_8gb.gin create mode 100644 trax/configs/transformer_lm_wmt_ende_8gb.ginE create mode 100644 trax/configs/transformer_ptb_16gb.gin create mode 100644 trax/configs/transformer_ptb_16gb.ginE create mode 100644 trax/configs/transformer_wmt_ende_16gb_adafactor_testing.gin create mode 100644 trax/configs/transformer_wmt_ende_16gb_adafactor_testing.ginE create mode 100644 trax/configs/transformer_wmt_ende_8gb_adafactor.gin create mode 100644 trax/configs/transformer_wmt_ende_8gb_adafactor.ginE create mode 100644 trax/configs/transformer_wmt_ende_8gb_adam.gin create mode 100644 trax/configs/transformer_wmt_ende_8gb_adam.ginE create mode 100644 trax/configs/transformer_wmt_ende_8gb_sm3.gin create mode 100644 trax/configs/transformer_wmt_ende_8gb_sm3.ginE create mode 100644 trax/configs/wide_resnet_cifar10_8gb.gin create mode 100644 trax/configs/wide_resnet_cifar10_8gb.ginE create mode 100644 trax/history.py create mode 100644 trax/inputs.py create mode 100644 trax/inputs_test.py create mode 100644 trax/jaxboard.py create mode 100644 trax/layers/README.md create mode 100644 trax/layers/__init__.py create mode 100644 trax/layers/attention.py create mode 100644 trax/layers/attention_test.py create mode 100644 trax/layers/base.py create mode 100644 trax/layers/base_test.py create mode 100644 trax/layers/combinators.py create mode 100644 trax/layers/combinators_test.py create mode 100644 trax/layers/convolution.py create mode 100644 trax/layers/convolution_test.py create mode 100644 trax/layers/core.py create mode 100644 trax/layers/core_test.py create mode 100644 trax/layers/initializers.py create mode 100644 trax/layers/initializers_test.py create mode 100644 trax/layers/intro.ipynb create mode 100644 trax/layers/intro.ipynbE create mode 100644 trax/layers/metrics.py create mode 100644 trax/layers/metrics_test.py create mode 100644 trax/layers/normalization.py create mode 100644 trax/layers/normalization_test.py create mode 100644 trax/layers/pooling.py create mode 100644 trax/layers/pooling_test.py create mode 100644 trax/layers/reversible.py create mode 100644 trax/layers/reversible_test.py create mode 100644 trax/layers/rnn.py create mode 100644 trax/layers/rnn_test.py create mode 100644 trax/learning_rate.py create mode 100644 trax/learning_rate_test.py create mode 100644 trax/models/__init__.py create mode 100644 trax/models/atari_cnn.py create mode 100644 trax/models/atari_cnn_test.py create mode 100644 trax/models/mlp.py create mode 100644 trax/models/mlp_test.py create mode 100644 trax/models/neural_gpu.py create mode 100644 trax/models/neural_gpu_test.py create mode 100644 trax/models/research/__init__.py create mode 100644 trax/models/research/position_lookup_transformer.py create mode 100644 trax/models/research/reformer.py create mode 100644 trax/models/research/reformer_test.py create mode 100644 trax/models/resnet.py create mode 100644 trax/models/resnet_test.py create mode 100644 trax/models/transformer.py create mode 100644 trax/models/transformer_test.py create mode 100644 trax/notebooks/trax_demo_iclr2019.ipynb create mode 100644 trax/optimizers/__init__.py create mode 100644 trax/optimizers/base.py create mode 100644 trax/rl/__init__.py create mode 100644 trax/rl/base_trainer.py create mode 100644 trax/rl/base_trainer_test.py create mode 100644 trax/rl/configs/acrobot.gin create mode 100644 trax/rl/configs/acrobot.ginE create mode 100644 trax/rl/configs/acrobot_transformer.gin create mode 100644 trax/rl/configs/acrobot_transformer.ginE create mode 100644 trax/rl/configs/atari.gin create mode 100644 trax/rl/configs/atari.ginE create mode 100644 trax/rl/configs/atari_regression_test.gin create mode 100644 trax/rl/configs/atari_regression_test.ginE create mode 100644 trax/rl/configs/env_online_tune_transformer_imagenet64_16gb.gin create mode 100644 trax/rl/configs/env_online_tune_transformer_imagenet64_16gb.ginE create mode 100644 trax/rl/configs/env_online_tune_transformer_lm1b_16gb.gin create mode 100644 trax/rl/configs/env_online_tune_transformer_lm1b_16gb.ginE create mode 100644 trax/rl/configs/env_online_tune_transformer_lm_wmt_ende_16gb.gin create mode 100644 trax/rl/configs/env_online_tune_transformer_lm_wmt_ende_16gb.ginE create mode 100644 trax/rl/configs/env_online_tune_transformer_ptb_16gb.gin create mode 100644 trax/rl/configs/env_online_tune_transformer_ptb_16gb.ginE create mode 100644 trax/rl/configs/env_online_tune_wide_resnet_cifar10_8gb.gin create mode 100644 trax/rl/configs/env_online_tune_wide_resnet_cifar10_8gb.ginE create mode 100644 trax/rl/configs/ppo_online_tune.gin create mode 100644 trax/rl/configs/ppo_online_tune.ginE create mode 100644 trax/rl/configs/ppo_online_tune_wide_resnet_cifar10.gin create mode 100644 trax/rl/configs/ppo_online_tune_wide_resnet_cifar10.ginE create mode 100644 trax/rl/configs/simple_online_tune.gin create mode 100644 trax/rl/configs/simple_online_tune.ginE create mode 100644 trax/rl/configs/simple_online_tune_serialized.gin create mode 100644 trax/rl/configs/simple_online_tune_serialized.ginE create mode 100644 trax/rl/envs/__init__.py create mode 100644 trax/rl/envs/async_trajectory_collector.py create mode 100644 trax/rl/envs/async_trajectory_collector_lib.py create mode 100644 trax/rl/envs/async_trajectory_collector_lib_test.py create mode 100644 trax/rl/envs/fake_env.py create mode 100644 trax/rl/envs/fake_env_test.py create mode 100644 trax/rl/envs/online_tune.py create mode 100644 trax/rl/envs/online_tune_env.py create mode 100644 trax/rl/envs/online_tune_env_test.py create mode 100644 trax/rl/envs/online_tune_test.py create mode 100644 trax/rl/online_tune.py create mode 100644 trax/rl/online_tune_test.py create mode 100644 trax/rl/ppo.py create mode 100644 trax/rl/ppo_test.py create mode 100644 trax/rl/ppo_trainer.py create mode 100644 trax/rl/ppo_trainer_test.py create mode 100644 trax/rl/serialization_utils.py create mode 100644 trax/rl/serialization_utils_test.py create mode 100644 trax/rl/simple.py create mode 100644 trax/rl/simple_test.py create mode 100644 trax/rl/simple_trainer.py create mode 100644 trax/rl/simple_trainer_test.py create mode 100644 trax/rl/simulated_env_problem.py create mode 100644 trax/rl/simulated_env_problem_test.py create mode 100644 trax/rl/space_serializer.py create mode 100644 trax/rl/space_serializer_test.py create mode 100644 trax/rl/trainers.py create mode 100644 trax/rl_trainer.py create mode 100644 trax/trainer.py create mode 100644 trax/trainer_lib.py create mode 100644 trax/trainer_lib_test.py create mode 100644 trax/utils.py diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 000000000..3207e455e --- /dev/null +++ b/AUTHORS @@ -0,0 +1,7 @@ +# This is the list of Trax authors for copyright purposes. +# +# This does not necessarily list everyone who has contributed code, since in +# some cases, their employer may be the copyright holder. To see the full list +# of contributors, see the revision history in source control. + +Google Inc. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..70c24403d --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,37 @@ +# How to Contribute + +# Issues + +* Please tag your issue with `bug`, `feature request`, or `question` to help us + effectively respond. +* Please include the versions of JAX or Tensorflow you are running. +* Please provide the command line you ran as well as the log output. + +# Pull Requests + +We'd love to accept your patches and contributions to this project. There are +just a few small guidelines you need to follow. + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement. You (or your employer) retain the copyright to your contribution, +this simply gives us permission to use and redistribute your contributions as +part of the project. Head over to to see +your current agreements on file or to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Code reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. Consult +[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. + +## Community Guidelines + +This project follows +[Google's Open Source Community Guidelines](https://opensource.google/conduct/). diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md new file mode 100644 index 000000000..670674b35 --- /dev/null +++ b/ISSUE_TEMPLATE.md @@ -0,0 +1,30 @@ +### Description + +... + +### Environment information + +``` +OS: + +$ pip freeze | grep tensor +# your output here + +$ pip freeze | grep jax +# your output here + +$ python -V +# your output here +``` + +### For bugs: reproduction and error logs + +``` +# Steps to reproduce: +... +``` + +``` +# Error logs: +... +``` diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..d64569567 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 000000000..7efb88c5f --- /dev/null +++ b/README.md @@ -0,0 +1,65 @@ +## `trax`: Train Neural Nets with JAX + +![train tracks](https://images.pexels.com/photos/461772/pexels-photo-461772.jpeg?dl&fit=crop&crop=entropy&w=640&h=426) + +### `trax`: T2T Radically Simpler with JAX + +*Why?* Because T2T has gotten too complex. We are simplifying the main code too, +but we wanted to try a more radical step. So you can write code as in pure +NumPy and debug directly. So you can easily pinpoint each line where things +happen and understand each function. But we also want it to run fast on +accelerators, and that's possible with [JAX](https://github.com/google/jax). + +*Status:* preview; things work: models train, checkpoints are saved, TensorBoard +has summaries, you can decode. But we are changing a lot every day for now. +Please let us know what we should add, delete, keep, change. We plan to move +the best parts into core JAX. + +*Entrypoints:* + +* Script: `trainer.py` +* Main library entrypoint: `trax.train` + +### Examples + +#### Example Colab + +See our example constructing language models from scratch in a GPU-backed colab notebook at +[Trax Demo](https://colab.research.google.com/github/google/trax/blob/master/trax/notebooks/trax_demo_iclr2019.ipynb) + +#### MLP on MNIST + + +``` +python -m trax.trainer \ + --dataset=mnist \ + --model=MLP \ + --config="train.train_steps=1000" +``` + +#### Resnet50 on Imagenet + + +``` +python -m trax.trainer \ + --config_file=$PWD/trax/configs/resnet50_imagenet_8gb.gin +``` + +#### TransformerDecoder on LM1B + + +``` +python -m trax.trainer \ + --config_file=$PWD/trax/configs/transformer_lm1b_8gb.gin +``` + +### How `trax` differs from T2T + +* Configuration is done with [`gin`](https://github.com/google/gin-config). + `trainer.py` takes `--config_file` as well as `--config` for file overrides. +* Models are defined with [`stax`](https://github.com/google/jax/blob/master/jax/experimental/stax.py) in + `models/`. They are made gin-configurable in `models/__init__.py`. +* Datasets are simple iterators over batches. Datasets from + [`tensorflow/datasets`](https://github.com/tensorflow/datasets) + and [`tensor2tensor`](https://github.com/tensorflow/tensor2tensor) + are built-in and can be addressed by name. diff --git a/docs/walkthrough.md b/docs/walkthrough.md new file mode 100644 index 000000000..7efb88c5f --- /dev/null +++ b/docs/walkthrough.md @@ -0,0 +1,65 @@ +## `trax`: Train Neural Nets with JAX + +![train tracks](https://images.pexels.com/photos/461772/pexels-photo-461772.jpeg?dl&fit=crop&crop=entropy&w=640&h=426) + +### `trax`: T2T Radically Simpler with JAX + +*Why?* Because T2T has gotten too complex. We are simplifying the main code too, +but we wanted to try a more radical step. So you can write code as in pure +NumPy and debug directly. So you can easily pinpoint each line where things +happen and understand each function. But we also want it to run fast on +accelerators, and that's possible with [JAX](https://github.com/google/jax). + +*Status:* preview; things work: models train, checkpoints are saved, TensorBoard +has summaries, you can decode. But we are changing a lot every day for now. +Please let us know what we should add, delete, keep, change. We plan to move +the best parts into core JAX. + +*Entrypoints:* + +* Script: `trainer.py` +* Main library entrypoint: `trax.train` + +### Examples + +#### Example Colab + +See our example constructing language models from scratch in a GPU-backed colab notebook at +[Trax Demo](https://colab.research.google.com/github/google/trax/blob/master/trax/notebooks/trax_demo_iclr2019.ipynb) + +#### MLP on MNIST + + +``` +python -m trax.trainer \ + --dataset=mnist \ + --model=MLP \ + --config="train.train_steps=1000" +``` + +#### Resnet50 on Imagenet + + +``` +python -m trax.trainer \ + --config_file=$PWD/trax/configs/resnet50_imagenet_8gb.gin +``` + +#### TransformerDecoder on LM1B + + +``` +python -m trax.trainer \ + --config_file=$PWD/trax/configs/transformer_lm1b_8gb.gin +``` + +### How `trax` differs from T2T + +* Configuration is done with [`gin`](https://github.com/google/gin-config). + `trainer.py` takes `--config_file` as well as `--config` for file overrides. +* Models are defined with [`stax`](https://github.com/google/jax/blob/master/jax/experimental/stax.py) in + `models/`. They are made gin-configurable in `models/__init__.py`. +* Datasets are simple iterators over batches. Datasets from + [`tensorflow/datasets`](https://github.com/tensorflow/datasets) + and [`tensor2tensor`](https://github.com/tensorflow/tensor2tensor) + are built-in and can be addressed by name. diff --git a/pylintrc b/pylintrc new file mode 100644 index 000000000..ab45e0220 --- /dev/null +++ b/pylintrc @@ -0,0 +1,221 @@ + + +[MASTER] + +# Pickle collected data for later comparisons. +persistent=no + +# Set the cache size for astng objects. +cache-size=500 + +# Ignore Py3 files +ignore=get_references_web.py,get_references_web_single_group.py + + +[REPORTS] + +# Set the output format. +# output-format=sorted-text + +# Put messages in a separate file for each module / package specified on the +# command line instead of printing them on stdout. Reports (if any) will be +# written in a file name "pylint_global.[txt|html]". +files-output=no + +# Tells whether to display a full report or only the messages. +reports=no + +# Disable the report(s) with the given id(s). +disable-report=R0001,R0002,R0003,R0004,R0101,R0102,R0201,R0202,R0220,R0401,R0402,R0701,R0801,R0901,R0902,R0903,R0904,R0911,R0912,R0913,R0914,R0915,R0921,R0922,R0923 + +# Error message template (continued on second line) +msg-template={msg_id}:{line:3} {obj}: {msg} [{symbol}] + + +[MESSAGES CONTROL] +# List of checkers and warnings to enable. +enable=indexing-exception,old-raise-syntax + +# List of checkers and warnings to disable. +disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,file-ignored,multiple-imports,c-extension-no-member,trailing-newlines,unsubscriptable-object,misplaced-comparison-constant,no-member,abstract-method,no-else-return,missing-docstring,wrong-import-order,protected-access,inconsistent-return-statements,invalid-unary-operand-type,import-error,no-name-in-module,arguments-differ,not-context-manager,unused-argument + +[BASIC] + +# Required attributes for module, separated by a comma +required-attributes= + +# Regular expression which should only match the name +# of functions or classes which do not require a docstring. +no-docstring-rgx=(__.*__|main) + +# Min length in lines of a function that requires a docstring. +docstring-min-length=10 + +# Regular expression which should only match correct module names. The +# leading underscore is sanctioned for private modules by Google's style +# guide. +# +# There are exceptions to the basic rule (_?[a-z][a-z0-9_]*) to cover +# requirements of Python's module system. +module-rgx=^(_?[a-z][a-z0-9_]*)|__init__$ + +# Regular expression which should only match correct module level names +const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression which should only match correct class attribute +class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression which should only match correct class names +class-rgx=^_?[A-Z][a-zA-Z0-9]*$ + +# Regular expression which should only match correct function names. +# 'camel_case' and 'snake_case' group names are used for consistency of naming +# styles across functions and methods. +function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ + + +# Regular expression which should only match correct method names. +# 'camel_case' and 'snake_case' group names are used for consistency of naming +# styles across functions and methods. 'exempt' indicates a name which is +# consistent with all naming styles. +method-rgx=(?x) + ^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase + |tearDownTestCase|setupSelf|tearDownClass|setUpClass + |(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next) + |(?P_{0,2}[A-Z][a-zA-Z0-9_]*) + |(?P_{0,2}[a-z][a-z0-9_]*))$ + + +# Regular expression which should only match correct instance attribute names +attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ + +# Regular expression which should only match correct argument names +argument-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression which should only match correct variable names +variable-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression which should only match correct list comprehension / +# generator expression variable names +inlinevar-rgx=^[a-z][a-z0-9_]*$ + +# Good variable names which should always be accepted, separated by a comma +good-names=main,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names= + +# List of builtins function names that should not be used, separated by a comma +bad-functions=input,apply,reduce + +# List of decorators that define properties, such as abc.abstractproperty. +property-classes=abc.abstractproperty + + +[TYPECHECK] + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# List of decorators that create context managers from functions, such as +# contextlib.contextmanager. +contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching names used for dummy variables (i.e. not used). +dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__,__new__,setUp + +# "class_" is also a valid for the first argument to a class method. +valid-classmethod-first-arg=cls,class_ + + +[EXCEPTIONS] + +overgeneral-exceptions=StandardError,Exception,BaseException + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub,TERMIOS,Bastion,rexec,sets + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=80 + +# Regexp for a line that is allowed to be longer than the limit. +# This "ignore" regex is today composed of several independent parts: +# (1) Long import lines +# (2) URLs in comments or pydocs. Detecting URLs by regex is a hard problem and +# no amount of tweaking will make a perfect regex AFAICT. This one is a good +# compromise. +# (3) Constant string literals at the start of files don't need to be broken +# across lines. Allowing long paths and urls to be on a single +# line. Also requires that the string not be a triplequoted string. +ignore-long-lines=(?x) + (^\s*(import|from)\s + |^\s*(\#\ )??$ + |^[a-zA-Z_][a-zA-Z0-9_]*\s*=\s*("[^"]\S+"|'[^']\S+') + ) + +# Maximum number of lines in a module +max-module-lines=99999 + +# String used as indentation unit. We differ from PEP8's normal 4 spaces. +indent-string=' ' + +# Do not warn about multiple statements on a single line for constructs like +# if test: stmt +single-line-if-stmt=y + +# Make sure : in dicts and trailing commas are checked for whitespace. +no-space-check= + + +[LOGGING] + +# Add logging modules. +logging-modules=logging,absl.logging + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes= + + +# Maximum line length for lambdas +short-func-length=1 + +# List of module members that should be marked as deprecated. +# All of the string functions are listed in 4.1.4 Deprecated string functions +# in the Python 2.4 docs. +deprecated-members=string.atof,string.atoi,string.atol,string.capitalize,string.expandtabs,string.find,string.rfind,string.index,string.rindex,string.count,string.lower,string.split,string.rsplit,string.splitfields,string.join,string.joinfields,string.lstrip,string.rstrip,string.strip,string.swapcase,string.translate,string.upper,string.ljust,string.rjust,string.center,string.zfill,string.replace,sys.exitfunc,sys.maxint + + +# List of exceptions that do not need to be mentioned in the Raises section of +# a docstring. +ignore-exceptions=AssertionError,NotImplementedError,StopIteration,TypeError + + +# Number of spaces of indent required when the last token on the preceding line +# is an open (, [, or {. +indent-after-paren=4 diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..3c6a1df2a --- /dev/null +++ b/setup.py @@ -0,0 +1,60 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Install trax.""" + +from setuptools import find_packages +from setuptools import setup + +setup( + name='trax', + version='1.0.0', + description='Trax', + author='Google Inc.', + author_email='no-reply@google.com', + url='http://github.com/google/trax', + license='Apache 2.0', + packages=find_packages(), + install_requires=[ + 'gin-config', + 'gym', + 'numpy', + 'scipy', + 'six', + 'jax', + 'jaxlib', + 'tensorflow-datasets', + 'absl-py', + ], + extras_require={ + 'tensorflow': ['tensorflow>=1.14.0'], + 'tensorflow_gpu': ['tensorflow-gpu>=1.14.0'], + 'tests': [ + 'attrs', + 'pytest', + 'mock', + 'pylint', + 'jupyter', + 'matplotlib', + ], + }, + classifiers=[ + 'Development Status :: 4 - Beta', + 'Intended Audience :: Developers', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + ], + keywords='tensorflow machine learning jax', +) diff --git a/trax/README.mdE b/trax/README.mdE new file mode 100644 index 000000000..459b9b916 --- /dev/null +++ b/trax/README.mdE @@ -0,0 +1,65 @@ +## `trax`: Train Neural Nets with JAX + +![train tracks](https://images.pexels.com/photos/461772/pexels-photo-461772.jpeg?dl&fit=crop&crop=entropy&w=640&h=426) + +### `trax`: T2T Radically Simpler with JAX + +*Why?* Because T2T has gotten too complex. We are simplifying the main code too, +but we wanted to try a more radical step. So you can write code as in pure +NumPy and debug directly. So you can easily pinpoint each line where things +happen and understand each function. But we also want it to run fast on +accelerators, and that's possible with [JAX](https://github.com/google/jax). + +*Status:* preview; things work: models train, checkpoints are saved, TensorBoard +has summaries, you can decode. But we are changing a lot every day for now. +Please let us know what we should add, delete, keep, change. We plan to move +the best parts into core JAX. + +*Entrypoints:* + +* Script: `trainer.py` +* Main library entrypoint: `trax.train` + +### Examples + +#### Example Colab + +See our example constructing language models from scratch in a GPU-backed colab notebook at +[Trax Demo](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/trax/notebooks/trax_demo_iclr2019.ipynb) + +#### MLP on MNIST + + +``` +python -m tensor2tensor.trax.trainer \ + --dataset=mnist \ + --model=MLP \ + --config="train.train_steps=1000" +``` + +#### Resnet50 on Imagenet + + +``` +python -m tensor2tensor.trax.trainer \ + --config_file=$PWD/trax/configs/resnet50_imagenet_8gb.gin +``` + +#### TransformerDecoder on LM1B + + +``` +python -m tensor2tensor.trax.trainer \ + --config_file=$PWD/trax/configs/transformer_lm1b_8gb.gin +``` + +### How `trax` differs from T2T + +* Configuration is done with [`gin`](https://github.com/google/gin-config). + `trainer.py` takes `--config_file` as well as `--config` for file overrides. +* Models are defined with [`stax`](https://github.com/google/jax/blob/master/jax/experimental/stax.py) in + `models/`. They are made gin-configurable in `models/__init__.py`. +* Datasets are simple iterators over batches. Datasets from + [`tensorflow/datasets`](https://github.com/tensorflow/datasets) + and [`tensor2tensor`](https://github.com/tensorflow/tensor2tensor) + are built-in and can be addressed by name. diff --git a/trax/__init__.py b/trax/__init__.py new file mode 100644 index 000000000..7fa0b7f96 --- /dev/null +++ b/trax/__init__.py @@ -0,0 +1,15 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/trax/backend.py b/trax/backend.py new file mode 100644 index 000000000..b4f2c87e6 --- /dev/null +++ b/trax/backend.py @@ -0,0 +1,308 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Trax backend: all the primitive functions needed.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import gin + +import jax +from jax import lax +from jax import random as jax_random +import jax.numpy as jnp +import jax.scipy.special as jax_special +import numpy as onp +import tensorflow_datasets as tfds + + + +def jax_conv(inp, fltr, window_strides, padding, dimension_numbers, + filter_dilation=None): + """A wrapper around `lax.conv_general_dilated`. + + It requires `dimension_numbers` and disallows `inp_dilation`. + + Args: + inp: an (N+2)-D array. The input of the convolution. + fltr: an (N+2)-D array. The filter (i.e. kernel) of the convolution. + window_strides: the strides for moving the convolution window. + padding: a string, either "VALID" or "SAME". The padding algorithm. + dimension_numbers: a tuple of three strings encoding the data format of + input, filter and output. "I" means input; "O" means output; "C" means + channel; other characters such as "W", "H" and "D" means spatial + dimensions. + filter_dilation: the dilation rates for the filter. Dilating the filter + means adding "holes" to the filter. + + Returns: + An (N+2)-D array. The convolution result. + """ + return lax.conv_general_dilated(inp, fltr, window_strides, padding, + lhs_dilation=None, + rhs_dilation=filter_dilation, + dimension_numbers=dimension_numbers) + + +def _pooling_general(inputs, reducer, init_val, rescaler=None, + pool_size=(2, 2), strides=None, padding="VALID"): + """Helper: general pooling computation used in pooling layers later.""" + spatial_strides = strides or (1,) * len(pool_size) + rescale = rescaler(pool_size, spatial_strides, padding) if rescaler else None + dims = (1,) + pool_size + (1,) # NHWC + strides = (1,) + spatial_strides + (1,) + out = lax.reduce_window(inputs, init_val, reducer, dims, strides, padding) + return rescale(out, inputs) if rescale else out + + +def jax_max_pool(x, pool_size, strides, padding): + return _pooling_general(x, lax.max, -jnp.inf, pool_size=pool_size, + strides=strides, padding=padding) + + +def jax_sum_pool(x, pool_size, strides, padding): + return _pooling_general(x, lax.add, 0., pool_size=pool_size, + strides=strides, padding=padding) + + +def _normalize_by_window_size(dims, spatial_strides, padding): # pylint: disable=invalid-name + def rescale(outputs, inputs): + one = jnp.ones(inputs.shape[1:-1], dtype=inputs.dtype) + window_sizes = lax.reduce_window( + one, 0., lax.add, dims, spatial_strides, padding) + return outputs / window_sizes[..., jnp.newaxis] + return rescale + + +def jax_avg_pool(x, pool_size, strides, padding): + return _pooling_general(x, lax.add, 0., _normalize_by_window_size, + pool_size, strides=strides, padding=padding) + + +def nested_map(x, f): + """Map the function f to the nested structure x (dicts, tuples, lists).""" + if isinstance(x, list): + return [nested_map(y, f) for y in x] + if isinstance(x, tuple): + return tuple([nested_map(y, f) for y in x]) + if isinstance(x, dict): + return {k: nested_map(v, f) for (k, v) in x.items()} + return f(x) + + +class ShapeType(object): + """Store shape and type.""" + + def __init__(self, shape, dtype): + self.shape = shape + self.dtype = dtype + + def __repr__(self): + return "[shape:" + str(self.shape) + ", dtype:" + str(self.dtype) + "]" + + +def jax_eval_on_shapes(f): + """Returns a function that evaluates `f` given input shapes and dtypes. + + It transforms function `f` to a function that performs the same computation as + `f` but only on shapes and dtypes (a.k.a. shape inference). + + Args: + f: the function to be transformed. + + Returns: + A function whose input arguments can be either the same as `f`'s or only + their shapes/dtypes represented by `ShapeType`, and whose return values are + `ShapeType`s with the same nested structure as `f`'s return values. + """ + def shape_fun(*args, **kwargs): + jax_shapes = jax.eval_shape(f, *args, **kwargs) + return nested_map(jax_shapes, lambda x: ShapeType(x.shape, x.dtype)) + return shape_fun + + +# The default value of dtype is different from jax_random.randint +def jax_randint(key, shape, minval, maxval, dtype=onp.int32): + """Sample uniform random values in [minval, maxval) with given shape/dtype. + + Args: + key: a PRNGKey used as the random key. + shape: a tuple of nonnegative integers representing the shape. + minval: int or array of ints broadcast-compatible with ``shape``, a minimum + (inclusive) value for the range. + maxval: int or array of ints broadcast-compatible with ``shape``, a maximum + (exclusive) value for the range. + dtype: optional, an int dtype for the returned values (default int32). + + Returns: + A random array with the specified shape and dtype. + """ + return jax_random.randint(key, shape, minval=minval, maxval=maxval, + dtype=dtype) + + +_JAX_BACKEND = { + "name": "jax", + "np": jnp, + "logsumexp": jax_special.logsumexp, + "expit": jax_special.expit, + "erf": jax_special.erf, + "conv": jax_conv, + "avg_pool": jax_avg_pool, + "max_pool": jax_max_pool, + "sum_pool": jax_sum_pool, + "jit": jax.jit, + "grad": jax.grad, + "pmap": jax.pmap, + "eval_on_shapes": jax_eval_on_shapes, + "random_uniform": jax_random.uniform, + "random_randint": jax_randint, + "random_normal": jax_random.normal, + "random_bernoulli": jax_random.bernoulli, + "random_get_prng": jax.jit(jax_random.PRNGKey), + "random_split": jax_random.split, + "dataset_as_numpy": tfds.as_numpy, +} + + +_NUMPY_BACKEND = { + "name": "numpy", + "np": onp, + "jit": (lambda f: f), + "random_get_prng": lambda seed: None, + "random_split": lambda prng, num=2: (None,) * num, + "expit": (lambda x: 1. / (1. + onp.exp(-x))), +} + + +def get_name(): + return backend()["name"] + + +def logsumexp(*args, **kwargs): + return backend()["logsumexp"](*args, **kwargs) + + +def expit(*args, **kwargs): + return backend()["expit"](*args, **kwargs) + + +def erf(*args, **kwargs): + return backend()["erf"](*args, **kwargs) + + +def conv(*args, **kwargs): + return backend()["conv"](*args, **kwargs) + + +def avg_pool(*args, **kwargs): + return backend()["avg_pool"](*args, **kwargs) + + +def max_pool(*args, **kwargs): + return backend()["max_pool"](*args, **kwargs) + + +def sum_pool(*args, **kwargs): + return backend()["sum_pool"](*args, **kwargs) + + +def jit(*args, **kwargs): + return backend()["jit"](*args, **kwargs) + + +def grad(*args, **kwargs): + return backend()["grad"](*args, **kwargs) + + +def pmap(*args, **kwargs): + return backend()["pmap"](*args, **kwargs) + + +def eval_on_shapes(*args, **kwargs): + return backend()["eval_on_shapes"](*args, **kwargs) + + +def dataset_as_numpy(*args, **kwargs): + return backend()["dataset_as_numpy"](*args, **kwargs) + + +# For numpy and random modules, we need to call "backend()" lazily, only when +# the function is called -- so that it can be set by gin configs. +# (Otherwise, backend() is called on import before gin-config is parsed.) +# To do that, we make objects to encapsulated these modules. + + +class RandomBackend(object): + """Backend providing random functions.""" + + def get_prng(self, seed): + return backend()["random_get_prng"](seed) + + def split(self, prng, num=2): + return backend()["random_split"](prng, num) + + def uniform(self, *args, **kwargs): + return backend()["random_uniform"](*args, **kwargs) + + def randint(self, *args, **kwargs): + return backend()["random_randint"](*args, **kwargs) + + def normal(self, *args, **kwargs): + return backend()["random_normal"](*args, **kwargs) + + def bernoulli(self, *args, **kwargs): + return backend()["random_bernoulli"](*args, **kwargs) + + +random = RandomBackend() + + +# A class that just forwards attribute accesses to backend's numpy object. +class NumpyBackend(object): + + def __getattr__(self, attr): + return getattr(backend()["np"], attr) + + +numpy = NumpyBackend() + + + + +override_backend_name = None + + +@gin.configurable() +def backend(name="jax"): + name = name if not override_backend_name else override_backend_name + if name == "numpy": + return _NUMPY_BACKEND + return _JAX_BACKEND + + +@contextlib.contextmanager +def use_backend(name): + global override_backend_name + prev_name = override_backend_name + override_backend_name = name + # Run the decorated function in try-finally in case it throws, e.g. for tests. + try: + yield + finally: + override_backend_name = prev_name diff --git a/trax/backend_test.py b/trax/backend_test.py new file mode 100644 index 000000000..98cf7928f --- /dev/null +++ b/trax/backend_test.py @@ -0,0 +1,75 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.backend.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gin +import jax.numpy as jnp +import numpy as onp +from tensorflow import test +from trax import backend as backend_lib + + +class BackendTest(test.TestCase): + + def setUp(self): + gin.clear_config() + + def override_gin(self, bindings): + gin.parse_config_files_and_bindings(None, bindings) + + def test_backend_imports_correctly(self): + backend = backend_lib.backend() + self.assertEqual(jnp, backend["np"]) + self.assertNotEqual(onp, backend["np"]) + + self.override_gin("backend.name = 'numpy'") + + backend = backend_lib.backend() + self.assertNotEqual(jnp, backend["np"]) + self.assertEqual(onp, backend["np"]) + + def test_numpy_backend_delegation(self): + # Assert that we are getting JAX's numpy backend. + backend = backend_lib.backend() + numpy = backend_lib.numpy + self.assertEqual(jnp, backend["np"]) + + # Assert that `numpy` calls the appropriate gin configured functions and + # properties. + self.assertTrue(numpy.isinf(numpy.inf)) + self.assertEqual(jnp.isinf, numpy.isinf) + self.assertEqual(jnp.inf, numpy.inf) + + # Assert that we will now get the pure numpy backend. + + self.override_gin("backend.name = 'numpy'") + + backend = backend_lib.backend() + numpy = backend_lib.numpy + self.assertEqual(onp, backend["np"]) + + # Assert that `numpy` calls the appropriate gin configured functions and + # properties. + self.assertTrue(numpy.isinf(numpy.inf)) + self.assertEqual(onp.isinf, numpy.isinf) + self.assertEqual(onp.inf, numpy.inf) + +if __name__ == "__main__": + test.main() diff --git a/trax/configs/mlp_mnist.gin b/trax/configs/mlp_mnist.gin new file mode 100644 index 000000000..74130fb07 --- /dev/null +++ b/trax/configs/mlp_mnist.gin @@ -0,0 +1,49 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.learning_rate +import trax.models +import trax.optimizers +import trax.trainer_lib + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 256 +batch_fn.eval_batch_size = 256 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 'mnist' + +# Parameters for MLP: +# ============================================================================== +MLP.d_hidden = 512 +MLP.n_hidden_layers = 2 +MLP.n_output_classes = 10 + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 0.1 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 400 + +# Parameters for train: +# ============================================================================== +train.optimizer = @trax.optimizers.Adafactor +train.eval_frequency = 200 +train.eval_steps = 10 +train.model = @trax.models.MLP +train.train_steps = 2000 diff --git a/trax/configs/mlp_mnist.ginE b/trax/configs/mlp_mnist.ginE new file mode 100644 index 000000000..75cbf7ba9 --- /dev/null +++ b/trax/configs/mlp_mnist.ginE @@ -0,0 +1,35 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.learning_rate +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.trax + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 256 +batch_fn.eval_batch_size = 256 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 'mnist' + +# Parameters for MLP: +# ============================================================================== +MLP.d_hidden = 512 +MLP.n_hidden_layers = 2 +MLP.n_output_classes = 10 + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 0.1 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 400 + +# Parameters for train: +# ============================================================================== +train.optimizer = @trax.optimizers.Adafactor +train.eval_frequency = 200 +train.eval_steps = 10 +train.model = @trax.models.MLP +train.train_steps = 2000 diff --git a/trax/configs/position_lookup_transformer_copy.gin b/trax/configs/position_lookup_transformer_copy.gin new file mode 100644 index 000000000..13c7bbf2d --- /dev/null +++ b/trax/configs/position_lookup_transformer_copy.gin @@ -0,0 +1,68 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.models +import trax.optimizers +import trax.trainer_lib + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 'position_lookup_copy' + +# Parameters for sequence_copy_inputs: +# ============================================================================== +sequence_copy_inputs.vocab_size = 128 +sequence_copy_inputs.batch_size = 16 +sequence_copy_inputs.train_lengths = [20, 30, 40] +sequence_copy_inputs.eval_lengths = [60] +sequence_copy_inputs.reverse = False + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 0.05 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 16000 + +# Parameters for PositionLookupTransformerLM: +# ============================================================================== +PositionLookupTransformerLM.d_model = 256 +PositionLookupTransformerLM.d_ff = 512 +PositionLookupTransformerLM.dropout = 0.01 +PositionLookupTransformerLM.max_len = 62 +PositionLookupTransformerLM.n_heads = 4 +PositionLookupTransformerLM.n_layers = 3 +PositionLookupTransformerLM.vocab_size = 128 + +# Parameters for TransformerLM: (same as above, for easy comparisons) +# ============================================================================== +TransformerLM.d_model = 256 +TransformerLM.d_ff = 512 +TransformerLM.dropout = 0.01 +TransformerLM.max_len = 62 +TransformerLM.n_heads = 4 +TransformerLM.n_layers = 3 +TransformerLM.vocab_size = 128 + +# Parameters for train: +# ============================================================================== +train.inputs = @trax.inputs.sequence_copy_inputs +train.eval_frequency = 1000 +train.eval_steps = 10 +train.model = @trax.models.PositionLookupTransformerLM +train.optimizer = @trax.optimizers.Adam +train.train_steps = 100000 +train.mask_id = 0 +train.has_weights = True diff --git a/trax/configs/position_lookup_transformer_copy.ginE b/trax/configs/position_lookup_transformer_copy.ginE new file mode 100644 index 000000000..036310433 --- /dev/null +++ b/trax/configs/position_lookup_transformer_copy.ginE @@ -0,0 +1,54 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.trax + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 'position_lookup_copy' + +# Parameters for sequence_copy_inputs: +# ============================================================================== +sequence_copy_inputs.vocab_size = 128 +sequence_copy_inputs.batch_size = 16 +sequence_copy_inputs.train_lengths = [20, 30, 40] +sequence_copy_inputs.eval_lengths = [60] +sequence_copy_inputs.reverse = False + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 0.05 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 16000 + +# Parameters for PositionLookupTransformerLM: +# ============================================================================== +PositionLookupTransformerLM.d_model = 256 +PositionLookupTransformerLM.d_ff = 512 +PositionLookupTransformerLM.dropout = 0.01 +PositionLookupTransformerLM.max_len = 62 +PositionLookupTransformerLM.n_heads = 4 +PositionLookupTransformerLM.n_layers = 3 +PositionLookupTransformerLM.vocab_size = 128 + +# Parameters for TransformerLM: (same as above, for easy comparisons) +# ============================================================================== +TransformerLM.d_model = 256 +TransformerLM.d_ff = 512 +TransformerLM.dropout = 0.01 +TransformerLM.max_len = 62 +TransformerLM.n_heads = 4 +TransformerLM.n_layers = 3 +TransformerLM.vocab_size = 128 + +# Parameters for train: +# ============================================================================== +train.inputs = @trax.inputs.sequence_copy_inputs +train.eval_frequency = 1000 +train.eval_steps = 10 +train.model = @trax.models.PositionLookupTransformerLM +train.optimizer = @trax.optimizers.Adam +train.train_steps = 100000 +train.mask_id = 0 +train.has_weights = True diff --git a/trax/configs/reformer_base_sweep.yaml b/trax/configs/reformer_base_sweep.yaml new file mode 100644 index 000000000..a8b3925c9 --- /dev/null +++ b/trax/configs/reformer_base_sweep.yaml @@ -0,0 +1,16 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +share_qk: [True, False] +attn_kv: [64, 128] diff --git a/trax/configs/reformer_base_sweep.yamlE b/trax/configs/reformer_base_sweep.yamlE new file mode 100644 index 000000000..490518f1e --- /dev/null +++ b/trax/configs/reformer_base_sweep.yamlE @@ -0,0 +1,2 @@ +share_qk: [True, False] +attn_kv: [64, 128] diff --git a/trax/configs/reformer_enwik8.gin b/trax/configs/reformer_enwik8.gin new file mode 100644 index 000000000..f66af0846 --- /dev/null +++ b/trax/configs/reformer_enwik8.gin @@ -0,0 +1,125 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.models +import trax.optimizers +import trax.trainer_lib + +# Parameters that will vary between experiments: +# ============================================================================== +train.model = @trax.models.ReformerLM +attn_type = @TimeBinCausalAttention +share_qk = True # Required when using LSHCausalAttention +attn_kv = 64 +n_layers = 3 +dropout = 0.1 + +# MemoryEfficientCausalAttention: full attention +# (no hparams to vary between experiments) + +# TimeBinCausalAttention: attend to nearby items +TimeBinCausalAttention.n_bins = 512 + +# LSHCausalAttention: locality-sensitive hashing (LSH) attention +LSHCausalAttention.n_bins = 512 +LSHCausalAttention.n_buckets = 1024 # Always 2 * n_bins +LSHCausalAttention.n_hashes = 2 +LSHCausalAttention.drop_for_hash_rate = 0.0 + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 1 +batch_fn.eval_batch_size = 8 +batch_fn.max_eval_length = 65536 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_enwik8_l65k' +inputs.input_name = 'targets' + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 2.0 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 8000 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 500 +train.eval_steps = 8 +train.inputs = @trax.inputs.inputs +# train.model: see top +train.optimizer = @trax.optimizers.Adafactor +train.train_steps = 60000 +train.save_graphs = False +train.save_steps = \ + [1000, 5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000, 45000, 50000, + 55000, 60000] + +# Parameters for MemoryEfficientCausalAttention: +# ============================================================================== +MemoryEfficientCausalAttention.dropout = 0.0 +MemoryEfficientCausalAttention.loop_stride = 256 +MemoryEfficientCausalAttention.share_qk = %share_qk + +# Parameters for TimeBinCausalAttention: +# ============================================================================== +TimeBinCausalAttention.dropout = 0.0 +# TimeBinCausalAttention.n_bins: see top +TimeBinCausalAttention.share_qk = %share_qk + +# Parameters for LSHCausalAttention: +# ============================================================================== +LSHCausalAttention.allow_duplicate_attention = False +LSHCausalAttention.attend_across_buckets = False +LSHCausalAttention.rehash_each_round = True +# LSHCausalAttention.n_bins: see top +# LSHCausalAttention.n_buckets: see top +# LSHCausalAttention.n_hashes: see top +LSHCausalAttention.one_rng = False +LSHCausalAttention.hard_k = 0 +LSHCausalAttention.dropout = 0.0 +# LSHCausalAttention.drop_for_hash_rate: see top + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = %attn_type +TransformerLM.d_attention_key = %attn_kv +TransformerLM.d_attention_value = %attn_kv +TransformerLM.d_model = 1024 +TransformerLM.d_ff = 4096 +TransformerLM.dropout = %dropout +TransformerLM.max_len = 65536 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = %n_layers +TransformerLM.share_qk = %share_qk +TransformerLM.vocab_size = 258 # Includes pad token and unused EOS token + +# Parameters for ReformerLM: +# ============================================================================== +ReformerLM.attention_type = %attn_type +ReformerLM.d_attention_key = %attn_kv +ReformerLM.d_attention_value = %attn_kv +ReformerLM.d_model = 1024 +ReformerLM.d_ff = 4096 +ReformerLM.dropout = %dropout +ReformerLM.max_len = 65536 +ReformerLM.mode = 'train' +ReformerLM.n_heads = 8 +ReformerLM.n_layers = %n_layers +ReformerLM.vocab_size = 258 # Includes pad token and unused EOS token +ReformerLM.share_qk = %share_qk diff --git a/trax/configs/reformer_enwik8.ginE b/trax/configs/reformer_enwik8.ginE new file mode 100644 index 000000000..17b3e9883 --- /dev/null +++ b/trax/configs/reformer_enwik8.ginE @@ -0,0 +1,115 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.trax + +# Parameters that will vary between experiments: +# ============================================================================== +train.model = @trax.models.ReformerLM +attn_type = @TimeBinCausalAttention +share_qk = True # Required when using LSHCausalAttention +attn_kv = 64 +n_layers = 3 +dropout = 0.1 + +# MemoryEfficientCausalAttention: full attention +# (no hparams to vary between experiments) + +# TimeBinCausalAttention: attend to nearby items +TimeBinCausalAttention.n_bins = 512 + +# LSHCausalAttention: locality-sensitive hashing (LSH) attention +LSHCausalAttention.n_bins = 512 +LSHCausalAttention.n_buckets = 1024 # Always 2 * n_bins +LSHCausalAttention.n_hashes = 2 +LSHCausalAttention.drop_for_hash_rate = 0.0 + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 1 +batch_fn.eval_batch_size = 8 +batch_fn.max_eval_length = 65536 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_enwik8_l65k' +inputs.input_name = 'targets' +inputs.n_chunks = 16 + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 2.0 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 8000 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 500 +train.eval_steps = 8 +train.inputs = @trax.inputs.inputs +# train.model: see top +train.optimizer = @trax.optimizers.Adafactor +train.train_steps = 60000 +train.trainer_class = @MemoryEfficientTrainer +train.save_steps = \ + [1000, 5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000, 45000, 50000, + 55000, 60000] + +# Parameters for MemoryEfficientCausalAttention: +# ============================================================================== +MemoryEfficientCausalAttention.dropout = 0.0 +MemoryEfficientCausalAttention.loop_stride = 256 +MemoryEfficientCausalAttention.share_qk = %share_qk + +# Parameters for TimeBinCausalAttention: +# ============================================================================== +TimeBinCausalAttention.dropout = 0.0 +# TimeBinCausalAttention.n_bins: see top +TimeBinCausalAttention.share_qk = %share_qk + +# Parameters for LSHCausalAttention: +# ============================================================================== +LSHCausalAttention.allow_duplicate_attention = False +LSHCausalAttention.attend_across_buckets = False +LSHCausalAttention.rehash_each_round = True +# LSHCausalAttention.n_bins: see top +# LSHCausalAttention.n_buckets: see top +# LSHCausalAttention.n_hashes: see top +LSHCausalAttention.one_rng = False +LSHCausalAttention.hard_k = 0 +LSHCausalAttention.dropout = 0.0 +# LSHCausalAttention.drop_for_hash_rate: see top + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = %attn_type +TransformerLM.d_attention_key = %attn_kv +TransformerLM.d_attention_value = %attn_kv +TransformerLM.d_model = 1024 +TransformerLM.d_ff = 4096 +TransformerLM.dropout = %dropout +TransformerLM.max_len = 65536 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = %n_layers +TransformerLM.n_chunks = 16 +TransformerLM.share_qk = %share_qk +TransformerLM.vocab_size = 258 # Includes pad token and unused EOS token + +# Parameters for ReformerLM: +# ============================================================================== +ReformerLM.attention_type = %attn_type +ReformerLM.d_attention_key = %attn_kv +ReformerLM.d_attention_value = %attn_kv +ReformerLM.d_model = 1024 +ReformerLM.d_ff = 4096 +ReformerLM.dropout = %dropout +ReformerLM.max_len = 65536 +ReformerLM.mode = 'train' +ReformerLM.n_heads = 8 +ReformerLM.n_layers = %n_layers +ReformerLM.vocab_size = 258 # Includes pad token and unused EOS token +ReformerLM.n_chunks = 16 +ReformerLM.n_attention_chunks = 1 +ReformerLM.share_qk = %share_qk diff --git a/trax/configs/reformer_hash_sweep.yaml b/trax/configs/reformer_hash_sweep.yaml new file mode 100644 index 000000000..8e7364758 --- /dev/null +++ b/trax/configs/reformer_hash_sweep.yaml @@ -0,0 +1,15 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +MergedMultiHashedCausalAttentionV2.n_hashes: [2, 4, 8, 16] diff --git a/trax/configs/reformer_hash_sweep.yamlE b/trax/configs/reformer_hash_sweep.yamlE new file mode 100644 index 000000000..94216b406 --- /dev/null +++ b/trax/configs/reformer_hash_sweep.yamlE @@ -0,0 +1 @@ +MergedMultiHashedCausalAttentionV2.n_hashes: [2, 4, 8, 16] diff --git a/trax/configs/reformer_imagenet64.gin b/trax/configs/reformer_imagenet64.gin new file mode 100644 index 000000000..5892295e8 --- /dev/null +++ b/trax/configs/reformer_imagenet64.gin @@ -0,0 +1,124 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.models +import trax.optimizers +import trax.trainer_lib + +# Parameters that will vary between experiments: +# ============================================================================== +train.model = @trax.models.ReformerLM +attn_type = @TimeBinCausalAttention +share_qk = True # Required when using LSHCausalAttention +attn_kv = 64 +n_layers = 3 + +# MemoryEfficientCausalAttention: full attention +# (no hparams to vary between experiments) + +# TimeBinCausalAttention: attend to nearby items +TimeBinCausalAttention.n_bins = 64 + +# LSHCausalAttention: locality-sensitive hashing (LSH) attention +LSHCausalAttention.n_bins = 96 +LSHCausalAttention.n_buckets = 192 # Always 2 * n_bins +LSHCausalAttention.n_hashes = 2 +LSHCausalAttention.drop_for_hash_rate = 0.0 + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 1 +batch_fn.eval_batch_size = 8 +batch_fn.max_eval_length = 12288 # 64 * 64 * 3 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_image_imagenet64_gen_flat_rev' +inputs.input_name = 'targets' + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 2.0 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 8000 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 500 +train.eval_steps = 64 +train.inputs = @trax.inputs.inputs +# train.model: see top +train.optimizer = @trax.optimizers.Adafactor +train.train_steps = 500000 +train.save_graphs = False +train.save_steps = \ + [1000, 5000, 10000, 20000, 40000, 60000, 80000, + 100000, 200000, 300000, 400000, 500000] + +# Parameters for MemoryEfficientCausalAttention: +# ============================================================================== +MemoryEfficientCausalAttention.dropout = 0.0 +MemoryEfficientCausalAttention.loop_stride = 512 +MemoryEfficientCausalAttention.share_qk = %share_qk + +# Parameters for TimeBinCausalAttention: +# ============================================================================== +TimeBinCausalAttention.dropout = 0.0 +# TimeBinCausalAttention.n_bins: see top +TimeBinCausalAttention.share_qk = %share_qk + +# Parameters for LSHCausalAttention: +# ============================================================================== +LSHCausalAttention.allow_duplicate_attention = False +LSHCausalAttention.attend_across_buckets = False +LSHCausalAttention.rehash_each_round = True +# LSHCausalAttention.n_bins: see top +# LSHCausalAttention.n_buckets: see top +# LSHCausalAttention.n_hashes: see top +LSHCausalAttention.one_rng = False +LSHCausalAttention.hard_k = 0 +LSHCausalAttention.dropout = 0.0 +# LSHCausalAttention.drop_for_hash_rate: see top + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = %attn_type +TransformerLM.d_attention_key = %attn_kv +TransformerLM.d_attention_value = %attn_kv +TransformerLM.d_model = 1024 +TransformerLM.d_ff = 4096 +TransformerLM.dropout = 0.0 +TransformerLM.max_len = 12288 # 64 * 64 * 3 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = %n_layers +TransformerLM.share_qk = %share_qk +TransformerLM.vocab_size = 256 + +# Parameters for ReformerLM: +# ============================================================================== +ReformerLM.attention_type = %attn_type +ReformerLM.d_attention_key = %attn_kv +ReformerLM.d_attention_value = %attn_kv +ReformerLM.d_model = 1024 +ReformerLM.d_ff = 4096 +ReformerLM.dropout = 0.0 +ReformerLM.max_len = 12288 # 64 * 64 * 3 +ReformerLM.mode = 'train' +ReformerLM.n_heads = 8 +ReformerLM.n_layers = %n_layers +ReformerLM.vocab_size = 256 +ReformerLM.share_qk = %share_qk diff --git a/trax/configs/reformer_imagenet64.ginE b/trax/configs/reformer_imagenet64.ginE new file mode 100644 index 000000000..9d9ae11c2 --- /dev/null +++ b/trax/configs/reformer_imagenet64.ginE @@ -0,0 +1,114 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.trax + +# Parameters that will vary between experiments: +# ============================================================================== +train.model = @trax.models.ReformerLM +attn_type = @TimeBinCausalAttention +share_qk = True # Required when using LSHCausalAttention +attn_kv = 64 +n_layers = 3 + +# MemoryEfficientCausalAttention: full attention +# (no hparams to vary between experiments) + +# TimeBinCausalAttention: attend to nearby items +TimeBinCausalAttention.n_bins = 64 + +# LSHCausalAttention: locality-sensitive hashing (LSH) attention +LSHCausalAttention.n_bins = 96 +LSHCausalAttention.n_buckets = 192 # Always 2 * n_bins +LSHCausalAttention.n_hashes = 2 +LSHCausalAttention.drop_for_hash_rate = 0.0 + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 1 +batch_fn.eval_batch_size = 8 +batch_fn.max_eval_length = 12288 # 64 * 64 * 3 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_image_imagenet64_gen_flat_rev' +inputs.input_name = 'targets' +inputs.n_chunks = 16 + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 2.0 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 8000 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 500 +train.eval_steps = 64 +train.inputs = @trax.inputs.inputs +# train.model: see top +train.optimizer = @trax.optimizers.Adafactor +train.train_steps = 500000 +train.trainer_class = @MemoryEfficientTrainer +train.save_steps = \ + [1000, 5000, 10000, 20000, 40000, 60000, 80000, + 100000, 200000, 300000, 400000, 500000] + +# Parameters for MemoryEfficientCausalAttention: +# ============================================================================== +MemoryEfficientCausalAttention.dropout = 0.0 +MemoryEfficientCausalAttention.loop_stride = 512 +MemoryEfficientCausalAttention.share_qk = %share_qk + +# Parameters for TimeBinCausalAttention: +# ============================================================================== +TimeBinCausalAttention.dropout = 0.0 +# TimeBinCausalAttention.n_bins: see top +TimeBinCausalAttention.share_qk = %share_qk + +# Parameters for LSHCausalAttention: +# ============================================================================== +LSHCausalAttention.allow_duplicate_attention = False +LSHCausalAttention.attend_across_buckets = False +LSHCausalAttention.rehash_each_round = True +# LSHCausalAttention.n_bins: see top +# LSHCausalAttention.n_buckets: see top +# LSHCausalAttention.n_hashes: see top +LSHCausalAttention.one_rng = False +LSHCausalAttention.hard_k = 0 +LSHCausalAttention.dropout = 0.0 +# LSHCausalAttention.drop_for_hash_rate: see top + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = %attn_type +TransformerLM.d_attention_key = %attn_kv +TransformerLM.d_attention_value = %attn_kv +TransformerLM.d_model = 1024 +TransformerLM.d_ff = 4096 +TransformerLM.dropout = 0.0 +TransformerLM.max_len = 12288 # 64 * 64 * 3 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = %n_layers +TransformerLM.n_chunks = 16 +TransformerLM.share_qk = %share_qk +TransformerLM.vocab_size = 256 + +# Parameters for ReformerLM: +# ============================================================================== +ReformerLM.attention_type = %attn_type +ReformerLM.d_attention_key = %attn_kv +ReformerLM.d_attention_value = %attn_kv +ReformerLM.d_model = 1024 +ReformerLM.d_ff = 4096 +ReformerLM.dropout = 0.0 +ReformerLM.max_len = 12288 # 64 * 64 * 3 +ReformerLM.mode = 'train' +ReformerLM.n_heads = 8 +ReformerLM.n_layers = %n_layers +ReformerLM.vocab_size = 256 +ReformerLM.n_chunks = 16 +ReformerLM.n_attention_chunks = 1 +ReformerLM.share_qk = %share_qk diff --git a/trax/configs/reformer_large_sweep.yaml b/trax/configs/reformer_large_sweep.yaml new file mode 100644 index 000000000..003013437 --- /dev/null +++ b/trax/configs/reformer_large_sweep.yaml @@ -0,0 +1,17 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +MergedMultiHashedCausalAttentionV2.n_hashes: [2, 4] +TransformerRevnetLM.n_layers: [12, 16, 20, 24] +MultifactorSchedule.constant: [0.3, 1.0] diff --git a/trax/configs/reformer_large_sweep.yamlE b/trax/configs/reformer_large_sweep.yamlE new file mode 100644 index 000000000..9cbf67088 --- /dev/null +++ b/trax/configs/reformer_large_sweep.yamlE @@ -0,0 +1,3 @@ +MergedMultiHashedCausalAttentionV2.n_hashes: [2, 4] +TransformerRevnetLM.n_layers: [12, 16, 20, 24] +MultifactorSchedule.constant: [0.3, 1.0] diff --git a/trax/configs/resnet50_imagenet_8gb.gin b/trax/configs/resnet50_imagenet_8gb.gin new file mode 100644 index 000000000..c1b50b137 --- /dev/null +++ b/trax/configs/resnet50_imagenet_8gb.gin @@ -0,0 +1,58 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.learning_rate +import trax.models +import trax.optimizers +import trax.trainer_lib + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 32 +batch_fn.bucket_length = 32 +batch_fn.buckets = None +batch_fn.eval_batch_size = 32 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_image_imagenet224' + +# Parameters for MultifactorSchedule: +# ============================================================================== +EvalAdjustingSchedule.constant = 0.2 +MultifactorSchedule.factors = 'constant * linear_warmup' +MultifactorSchedule.warmup_steps = 400 + +# Parameters for Momentum: +# ============================================================================== +Momentum.mass = 0.9 + + +# Parameters for Resnet50: +# ============================================================================== +Resnet50.d_hidden = 64 +Resnet50.n_output_classes = 1001 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 2000 +train.eval_steps = 20 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.Resnet50 +train.optimizer = @trax.optimizers.Momentum +train.train_steps = 1000000 +train.lr_schedule = @learning_rate.EvalAdjustingSchedule + diff --git a/trax/configs/resnet50_imagenet_8gb.ginE b/trax/configs/resnet50_imagenet_8gb.ginE new file mode 100644 index 000000000..7eb13a225 --- /dev/null +++ b/trax/configs/resnet50_imagenet_8gb.ginE @@ -0,0 +1,44 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.learning_rate +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.trax + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 32 +batch_fn.bucket_length = 32 +batch_fn.buckets = None +batch_fn.eval_batch_size = 32 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_image_imagenet224' + +# Parameters for MultifactorSchedule: +# ============================================================================== +EvalAdjustingSchedule.constant = 0.2 +MultifactorSchedule.factors = 'constant * linear_warmup' +MultifactorSchedule.warmup_steps = 400 + +# Parameters for Momentum: +# ============================================================================== +Momentum.mass = 0.9 + + +# Parameters for Resnet50: +# ============================================================================== +Resnet50.d_hidden = 64 +Resnet50.n_output_classes = 1001 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 2000 +train.eval_steps = 20 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.Resnet50 +train.optimizer = @trax.optimizers.Momentum +train.train_steps = 1000000 +train.lr_schedule = @learning_rate.EvalAdjustingSchedule + diff --git a/trax/configs/resnet50_imagenet_8gb_testing.gin b/trax/configs/resnet50_imagenet_8gb_testing.gin new file mode 100644 index 000000000..a8e2ec3d9 --- /dev/null +++ b/trax/configs/resnet50_imagenet_8gb_testing.gin @@ -0,0 +1,58 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.learning_rate +import trax.models +import trax.optimizers +import trax.trainer_lib + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 32 +batch_fn.bucket_length = 32 +batch_fn.buckets = None +batch_fn.eval_batch_size = 32 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_image_imagenet224' + +# Parameters for MultifactorSchedule: +# ============================================================================== +EvalAdjustingSchedule.constant = 0.2 +MultifactorSchedule.factors = 'constant * linear_warmup' +MultifactorSchedule.warmup_steps = 400 + +# Parameters for Momentum: +# ============================================================================== +Momentum.mass = 0.9 + + +# Parameters for Resnet50: +# ============================================================================== +Resnet50.d_hidden = 64 +Resnet50.n_output_classes = 1001 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 2000 +train.eval_steps = 20 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.Resnet50 +train.optimizer = @trax.optimizers.Momentum +train.train_steps = 100000 +train.lr_schedule = @learning_rate.EvalAdjustingSchedule + diff --git a/trax/configs/resnet50_imagenet_8gb_testing.ginE b/trax/configs/resnet50_imagenet_8gb_testing.ginE new file mode 100644 index 000000000..6a49f89fd --- /dev/null +++ b/trax/configs/resnet50_imagenet_8gb_testing.ginE @@ -0,0 +1,44 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.learning_rate +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.trax + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 32 +batch_fn.bucket_length = 32 +batch_fn.buckets = None +batch_fn.eval_batch_size = 32 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_image_imagenet224' + +# Parameters for MultifactorSchedule: +# ============================================================================== +EvalAdjustingSchedule.constant = 0.2 +MultifactorSchedule.factors = 'constant * linear_warmup' +MultifactorSchedule.warmup_steps = 400 + +# Parameters for Momentum: +# ============================================================================== +Momentum.mass = 0.9 + + +# Parameters for Resnet50: +# ============================================================================== +Resnet50.d_hidden = 64 +Resnet50.n_output_classes = 1001 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 2000 +train.eval_steps = 20 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.Resnet50 +train.optimizer = @trax.optimizers.Momentum +train.train_steps = 100000 +train.lr_schedule = @learning_rate.EvalAdjustingSchedule + diff --git a/trax/configs/transformer_big_lm1b_8gb.gin b/trax/configs/transformer_big_lm1b_8gb.gin new file mode 100644 index 000000000..b7d11d496 --- /dev/null +++ b/trax/configs/transformer_big_lm1b_8gb.gin @@ -0,0 +1,68 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.models +import trax.optimizers +import trax.trainer_lib + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 32 +batch_fn.eval_batch_size = 64 +batch_fn.max_eval_length = 512 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_languagemodel_lm1b32k' +inputs.input_name = 'targets' + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 0.1 +MultifactorSchedule.factors = 'constant * linear_warmup' +MultifactorSchedule.warmup_steps = 16000 + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.lm1b_preprocess +lm1b_preprocess.max_target_length = 512 +lm1b_preprocess.max_eval_target_length = 512 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 1000 +train.eval_steps = 10 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.TransformerLM +train.optimizer = @trax.optimizers.SM3 +train.train_steps = 500000 +train.mask_id = 0 + +# Parameters for DotProductCausalAttention: +# ============================================================================== +DotProductCausalAttention.dropout = 0.1 + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = @trax.layers.DotProductCausalAttention +TransformerLM.d_model = 1024 +TransformerLM.d_ff = 8192 +TransformerLM.dropout = 0.1 +TransformerLM.max_len = 2048 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = 8 +TransformerLM.vocab_size = 32000 diff --git a/trax/configs/transformer_big_lm1b_8gb.ginE b/trax/configs/transformer_big_lm1b_8gb.ginE new file mode 100644 index 000000000..1125eb032 --- /dev/null +++ b/trax/configs/transformer_big_lm1b_8gb.ginE @@ -0,0 +1,54 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.trax + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 32 +batch_fn.eval_batch_size = 64 +batch_fn.max_eval_length = 512 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_languagemodel_lm1b32k' +inputs.input_name = 'targets' + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 0.1 +MultifactorSchedule.factors = 'constant * linear_warmup' +MultifactorSchedule.warmup_steps = 16000 + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.lm1b_preprocess +lm1b_preprocess.max_target_length = 512 +lm1b_preprocess.max_eval_target_length = 512 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 1000 +train.eval_steps = 10 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.TransformerLM +train.optimizer = @trax.optimizers.SM3 +train.train_steps = 500000 +train.mask_id = 0 + +# Parameters for DotProductCausalAttention: +# ============================================================================== +DotProductCausalAttention.dropout = 0.1 + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = @trax.layers.DotProductCausalAttention +TransformerLM.d_model = 1024 +TransformerLM.d_ff = 8192 +TransformerLM.dropout = 0.1 +TransformerLM.max_len = 2048 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = 8 +TransformerLM.vocab_size = 32000 diff --git a/trax/configs/transformer_copy.gin b/trax/configs/transformer_copy.gin new file mode 100644 index 000000000..564f0286c --- /dev/null +++ b/trax/configs/transformer_copy.gin @@ -0,0 +1,91 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.models +import trax.optimizers +import trax.trainer_lib + +n_symbols = 128 +length = 4096 +batch = 16 + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size = %batch +batch_fn.eval_batch_size = %batch +batch_fn.max_eval_length = %length + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 'sequence_copy' + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 1.0 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 8000 + +# Parameters for sequence_copy_inputs: +# ============================================================================== +sequence_copy_inputs.vocab_size = %n_symbols +sequence_copy_inputs.batch_size = %batch +sequence_copy_inputs.train_lengths = [%length] +sequence_copy_inputs.eval_lengths = [%length] +sequence_copy_inputs.reverse = False + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 500 +train.eval_steps = 64 +train.inputs = @trax.inputs.sequence_copy_inputs +train.model = @trax.models.TransformerLM +train.optimizer = @trax.optimizers.Adafactor +train.train_steps = 50000 +train.has_weights = True + +# Parameters for MemoryEfficientCausalAttention: +# ============================================================================== +MemoryEfficientCausalAttention.dropout = 0.0 +MemoryEfficientCausalAttention.loop_stride = 512 + +# Parameters for LSHCausalAttention: +# ============================================================================== +LSHCausalAttention.allow_duplicate_attention = False +LSHCausalAttention.attend_across_buckets = True +LSHCausalAttention.rehash_each_round = True +LSHCausalAttention.n_bins = 64 +LSHCausalAttention.n_buckets = 128 +LSHCausalAttention.n_hashes = 8 +LSHCausalAttention.one_rng = False +LSHCausalAttention.hard_k = 0 +LSHCausalAttention.dropout = 0.0 +LSHCausalAttention.drop_for_hash_rate = 0.1 +LSHCausalAttention.factorize_hash = True + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = @trax.layers.LSHCausalAttention +TransformerLM.d_attention_key = 64 +TransformerLM.d_attention_value = 64 +TransformerLM.d_model = 256 +TransformerLM.d_ff = 256 +TransformerLM.dropout = 0.0 +TransformerLM.max_len = %length +TransformerLM.mode = 'train' +TransformerLM.n_heads = 4 +TransformerLM.n_layers = 1 +TransformerLM.share_qk = True +TransformerLM.vocab_size = %n_symbols diff --git a/trax/configs/transformer_copy.ginE b/trax/configs/transformer_copy.ginE new file mode 100644 index 000000000..86bfbead8 --- /dev/null +++ b/trax/configs/transformer_copy.ginE @@ -0,0 +1,77 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.trax + +n_symbols = 128 +length = 4096 +batch = 16 + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size = %batch +batch_fn.eval_batch_size = %batch +batch_fn.max_eval_length = %length + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 'sequence_copy' + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 1.0 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 8000 + +# Parameters for sequence_copy_inputs: +# ============================================================================== +sequence_copy_inputs.vocab_size = %n_symbols +sequence_copy_inputs.batch_size = %batch +sequence_copy_inputs.train_lengths = [%length] +sequence_copy_inputs.eval_lengths = [%length] +sequence_copy_inputs.reverse = False + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 500 +train.eval_steps = 64 +train.inputs = @trax.inputs.sequence_copy_inputs +train.model = @trax.models.TransformerLM +train.optimizer = @trax.optimizers.Adafactor +train.train_steps = 50000 +train.has_weights = True + +# Parameters for MemoryEfficientCausalAttention: +# ============================================================================== +MemoryEfficientCausalAttention.dropout = 0.0 +MemoryEfficientCausalAttention.loop_stride = 512 + +# Parameters for LSHCausalAttention: +# ============================================================================== +LSHCausalAttention.allow_duplicate_attention = False +LSHCausalAttention.attend_across_buckets = True +LSHCausalAttention.rehash_each_round = True +LSHCausalAttention.n_bins = 64 +LSHCausalAttention.n_buckets = 128 +LSHCausalAttention.n_hashes = 8 +LSHCausalAttention.one_rng = False +LSHCausalAttention.hard_k = 0 +LSHCausalAttention.dropout = 0.0 +LSHCausalAttention.drop_for_hash_rate = 0.1 +LSHCausalAttention.factorize_hash = True + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = @trax.layers.LSHCausalAttention +TransformerLM.d_attention_key = 64 +TransformerLM.d_attention_value = 64 +TransformerLM.d_model = 256 +TransformerLM.d_ff = 256 +TransformerLM.dropout = 0.0 +TransformerLM.max_len = %length +TransformerLM.mode = 'train' +TransformerLM.n_heads = 4 +TransformerLM.n_layers = 1 +TransformerLM.share_qk = True +TransformerLM.vocab_size = %n_symbols diff --git a/trax/configs/transformer_imdb_8gb.gin b/trax/configs/transformer_imdb_8gb.gin new file mode 100644 index 000000000..c35bd565e --- /dev/null +++ b/trax/configs/transformer_imdb_8gb.gin @@ -0,0 +1,68 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.models +import trax.optimizers +import trax.trainer_lib + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 128 +batch_fn.eval_batch_size = 128 +batch_fn.max_eval_length = 2048 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_sentiment_imdb' +inputs.input_name = 'targets' + + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 0.1 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 8000 + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.lm1b_preprocess +lm1b_preprocess.max_target_length = 512 +lm1b_preprocess.max_eval_target_length = 2048 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 100 +train.eval_steps = 10 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.TransformerEncoder +train.train_steps = 1000 + +# Parameters for DotProductCausalAttention: +# ============================================================================== +DotProductCausalAttention.dropout = 0.1 + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = @trax.layers.DotProductCausalAttention +TransformerEncoder.d_model = 512 +TransformerEncoder.d_ff = 2048 +TransformerEncoder.dropout = 0.1 +TransformerEncoder.max_len = 2048 +TransformerEncoder.mode = 'train' +TransformerEncoder.n_classes = 10 +TransformerEncoder.n_heads = 8 +TransformerEncoder.n_layers = 6 +TransformerEncoder.vocab_size = 32000 diff --git a/trax/configs/transformer_imdb_8gb.ginE b/trax/configs/transformer_imdb_8gb.ginE new file mode 100644 index 000000000..b4cd87f81 --- /dev/null +++ b/trax/configs/transformer_imdb_8gb.ginE @@ -0,0 +1,54 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.trax + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 128 +batch_fn.eval_batch_size = 128 +batch_fn.max_eval_length = 2048 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_sentiment_imdb' +inputs.input_name = 'targets' + + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 0.1 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 8000 + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.lm1b_preprocess +lm1b_preprocess.max_target_length = 512 +lm1b_preprocess.max_eval_target_length = 2048 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 100 +train.eval_steps = 10 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.TransformerEncoder +train.train_steps = 1000 + +# Parameters for DotProductCausalAttention: +# ============================================================================== +DotProductCausalAttention.dropout = 0.1 + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = @trax.layers.DotProductCausalAttention +TransformerEncoder.d_model = 512 +TransformerEncoder.d_ff = 2048 +TransformerEncoder.dropout = 0.1 +TransformerEncoder.max_len = 2048 +TransformerEncoder.mode = 'train' +TransformerEncoder.n_classes = 10 +TransformerEncoder.n_heads = 8 +TransformerEncoder.n_layers = 6 +TransformerEncoder.vocab_size = 32000 diff --git a/trax/configs/transformer_lm1b_16gb.gin b/trax/configs/transformer_lm1b_16gb.gin new file mode 100644 index 000000000..cd39583ae --- /dev/null +++ b/trax/configs/transformer_lm1b_16gb.gin @@ -0,0 +1,141 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.models +import trax.optimizers +import trax.trainer_lib + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 256 +batch_fn.eval_batch_size = 256 +batch_fn.max_eval_length = 2048 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_languagemodel_lm1b32k' +inputs.input_name = 'targets' + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 1.0 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 8000 + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.lm1b_preprocess +lm1b_preprocess.max_target_length = 512 +lm1b_preprocess.max_eval_target_length = 2048 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 500 +train.eval_steps = 1 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.TransformerLM +train.optimizer = @trax.optimizers.Adafactor +train.train_steps = 50000 +train.mask_id = 0 + +# Parameters for DotProductCausalAttention: +# ============================================================================== +DotProductCausalAttention.dropout = 0.1 + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = @trax.layers.DotProductCausalAttention +TransformerLM.d_model = 512 +TransformerLM.d_ff = 2048 +TransformerLM.dropout = 0.1 +TransformerLM.max_len = 2048 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = 6 +TransformerLM.vocab_size = 32000 + +# ============================================================================== +# Parameters for the RL hyperparameter tuner; turn on with +# train.lr_schedule=@learning_rate.PolicySchedule and set +# PolicySchedule.policy_dir. +# ============================================================================== + +# Parameters for PolicySchedule: +# ============================================================================== +PolicySchedule.observation_metrics = ( + ("train", "metrics/accuracy"), + ("train", "metrics/loss"), + ("eval", "metrics/accuracy"), + ("eval", "metrics/loss"), +) +PolicySchedule.include_controls_in_observation = False +PolicySchedule.control_configs = ( + ("learning_rate", 1e-3, (1e-9, 1e-2), False), + ("weight_decay_rate", 1e-5, (1e-9, 1e-3), False), + + ("dropout_embedding", 0.1, (0.0, 0.9), True), + + ("dropout_attention_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_initial", 0.1, (0.0, 0.9), True), + + ("dropout_attention_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_middle", 0.1, (0.0, 0.9), True), + + ("dropout_attention_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_final", 0.1, (0.0, 0.9), True), +) +PolicySchedule.observation_range = (0.0, 10.0) +PolicySchedule.action_multipliers = (0.5, 0.8, 0.95, 1.0, 1.05, 1.25, 2.0) +PolicySchedule.policy_and_value_model = @trax.models.TransformerDecoder +PolicySchedule.policy_and_value_two_towers = False + +# Parameters for train: +# ============================================================================== +train.nontrainable_param_map = { + # "dropout_{layer_type}_{block_index}": "dropout_{layer_type}_{block_group}" + + "dropout_attention_0": "dropout_attention_initial", + "dropout_ff_middle_0": "dropout_ff_middle_initial", + "dropout_ff_final_0": "dropout_ff_final_initial", + + "dropout_attention_1": "dropout_attention_middle", + "dropout_ff_middle_1": "dropout_ff_middle_middle", + "dropout_ff_final_1": "dropout_ff_final_middle", + "dropout_attention_2": "dropout_attention_middle", + "dropout_ff_middle_2": "dropout_ff_middle_middle", + "dropout_ff_final_2": "dropout_ff_final_middle", + "dropout_attention_3": "dropout_attention_middle", + "dropout_ff_middle_3": "dropout_ff_middle_middle", + "dropout_ff_final_3": "dropout_ff_final_middle", + "dropout_attention_4": "dropout_attention_middle", + "dropout_ff_middle_4": "dropout_ff_middle_middle", + "dropout_ff_final_4": "dropout_ff_final_middle", + + "dropout_attention_5": "dropout_attention_final", + "dropout_ff_middle_5": "dropout_ff_middle_final", + "dropout_ff_final_5": "dropout_ff_final_final", +} + +# Parameters for TransformerDecoder: +# ============================================================================== +TransformerDecoder.d_model = 64 +TransformerDecoder.d_ff = 128 +TransformerDecoder.dropout = 0.0 +TransformerDecoder.n_heads = 2 +TransformerDecoder.n_layers = 2 diff --git a/trax/configs/transformer_lm1b_16gb.ginE b/trax/configs/transformer_lm1b_16gb.ginE new file mode 100644 index 000000000..d0725cdbb --- /dev/null +++ b/trax/configs/transformer_lm1b_16gb.ginE @@ -0,0 +1,127 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.trax + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 256 +batch_fn.eval_batch_size = 256 +batch_fn.max_eval_length = 2048 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_languagemodel_lm1b32k' +inputs.input_name = 'targets' + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 1.0 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 8000 + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.lm1b_preprocess +lm1b_preprocess.max_target_length = 512 +lm1b_preprocess.max_eval_target_length = 2048 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 500 +train.eval_steps = 1 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.TransformerLM +train.optimizer = @trax.optimizers.Adafactor +train.train_steps = 50000 +train.mask_id = 0 + +# Parameters for DotProductCausalAttention: +# ============================================================================== +DotProductCausalAttention.dropout = 0.1 + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = @trax.layers.DotProductCausalAttention +TransformerLM.d_model = 512 +TransformerLM.d_ff = 2048 +TransformerLM.dropout = 0.1 +TransformerLM.max_len = 2048 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = 6 +TransformerLM.vocab_size = 32000 + +# ============================================================================== +# Parameters for the RL hyperparameter tuner; turn on with +# train.lr_schedule=@learning_rate.PolicySchedule and set +# PolicySchedule.policy_dir. +# ============================================================================== + +# Parameters for PolicySchedule: +# ============================================================================== +PolicySchedule.observation_metrics = ( + ("train", "metrics/accuracy"), + ("train", "metrics/loss"), + ("eval", "metrics/accuracy"), + ("eval", "metrics/loss"), +) +PolicySchedule.include_controls_in_observation = False +PolicySchedule.control_configs = ( + ("learning_rate", 1e-3, (1e-9, 1e-2), False), + ("weight_decay_rate", 1e-5, (1e-9, 1e-3), False), + + ("dropout_embedding", 0.1, (0.0, 0.9), True), + + ("dropout_attention_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_initial", 0.1, (0.0, 0.9), True), + + ("dropout_attention_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_middle", 0.1, (0.0, 0.9), True), + + ("dropout_attention_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_final", 0.1, (0.0, 0.9), True), +) +PolicySchedule.observation_range = (0.0, 10.0) +PolicySchedule.action_multipliers = (0.5, 0.8, 0.95, 1.0, 1.05, 1.25, 2.0) +PolicySchedule.policy_and_value_model = @trax.models.TransformerDecoder +PolicySchedule.policy_and_value_two_towers = False + +# Parameters for train: +# ============================================================================== +train.nontrainable_param_map = { + # "dropout_{layer_type}_{block_index}": "dropout_{layer_type}_{block_group}" + + "dropout_attention_0": "dropout_attention_initial", + "dropout_ff_middle_0": "dropout_ff_middle_initial", + "dropout_ff_final_0": "dropout_ff_final_initial", + + "dropout_attention_1": "dropout_attention_middle", + "dropout_ff_middle_1": "dropout_ff_middle_middle", + "dropout_ff_final_1": "dropout_ff_final_middle", + "dropout_attention_2": "dropout_attention_middle", + "dropout_ff_middle_2": "dropout_ff_middle_middle", + "dropout_ff_final_2": "dropout_ff_final_middle", + "dropout_attention_3": "dropout_attention_middle", + "dropout_ff_middle_3": "dropout_ff_middle_middle", + "dropout_ff_final_3": "dropout_ff_final_middle", + "dropout_attention_4": "dropout_attention_middle", + "dropout_ff_middle_4": "dropout_ff_middle_middle", + "dropout_ff_final_4": "dropout_ff_final_middle", + + "dropout_attention_5": "dropout_attention_final", + "dropout_ff_middle_5": "dropout_ff_middle_final", + "dropout_ff_final_5": "dropout_ff_final_final", +} + +# Parameters for TransformerDecoder: +# ============================================================================== +TransformerDecoder.d_model = 64 +TransformerDecoder.d_ff = 128 +TransformerDecoder.dropout = 0.0 +TransformerDecoder.n_heads = 2 +TransformerDecoder.n_layers = 2 diff --git a/trax/configs/transformer_lm1b_8gb.gin b/trax/configs/transformer_lm1b_8gb.gin new file mode 100644 index 000000000..87a4275f8 --- /dev/null +++ b/trax/configs/transformer_lm1b_8gb.gin @@ -0,0 +1,68 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.models +import trax.optimizers +import trax.trainer_lib + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 128 +batch_fn.eval_batch_size = 128 +batch_fn.max_eval_length = 2048 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_languagemodel_lm1b32k' +inputs.input_name = 'targets' + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 0.3 +MultifactorSchedule.factors = 'constant * linear_warmup' +MultifactorSchedule.warmup_steps = 8000 + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.lm1b_preprocess +lm1b_preprocess.max_target_length = 512 +lm1b_preprocess.max_eval_target_length = 2048 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 1000 +train.eval_steps = 10 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.TransformerLM +train.optimizer = @trax.optimizers.SM3 +train.train_steps = 500000 +train.mask_id = 0 + +# Parameters for DotProductCausalAttention: +# ============================================================================== +DotProductCausalAttention.dropout = 0.1 + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = @trax.layers.DotProductCausalAttention +TransformerLM.d_model = 512 +TransformerLM.d_ff = 2048 +TransformerLM.dropout = 0.1 +TransformerLM.max_len = 2048 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = 6 +TransformerLM.vocab_size = 32000 diff --git a/trax/configs/transformer_lm1b_8gb.ginE b/trax/configs/transformer_lm1b_8gb.ginE new file mode 100644 index 000000000..0cfd5c434 --- /dev/null +++ b/trax/configs/transformer_lm1b_8gb.ginE @@ -0,0 +1,54 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.trax + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 128 +batch_fn.eval_batch_size = 128 +batch_fn.max_eval_length = 2048 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_languagemodel_lm1b32k' +inputs.input_name = 'targets' + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 0.3 +MultifactorSchedule.factors = 'constant * linear_warmup' +MultifactorSchedule.warmup_steps = 8000 + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.lm1b_preprocess +lm1b_preprocess.max_target_length = 512 +lm1b_preprocess.max_eval_target_length = 2048 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 1000 +train.eval_steps = 10 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.TransformerLM +train.optimizer = @trax.optimizers.SM3 +train.train_steps = 500000 +train.mask_id = 0 + +# Parameters for DotProductCausalAttention: +# ============================================================================== +DotProductCausalAttention.dropout = 0.1 + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = @trax.layers.DotProductCausalAttention +TransformerLM.d_model = 512 +TransformerLM.d_ff = 2048 +TransformerLM.dropout = 0.1 +TransformerLM.max_len = 2048 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = 6 +TransformerLM.vocab_size = 32000 diff --git a/trax/configs/transformer_lm1b_8gb_testing.gin b/trax/configs/transformer_lm1b_8gb_testing.gin new file mode 100644 index 000000000..336bcc0aa --- /dev/null +++ b/trax/configs/transformer_lm1b_8gb_testing.gin @@ -0,0 +1,68 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.models +import trax.optimizers +import trax.trainer_lib + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 128 +batch_fn.eval_batch_size = 128 +batch_fn.max_eval_length = 2048 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_languagemodel_lm1b32k' +inputs.input_name = 'targets' + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 0.1 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 8000 + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.lm1b_preprocess +lm1b_preprocess.max_target_length = 512 +lm1b_preprocess.max_eval_target_length = 2048 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 1000 +train.eval_steps = 10 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.TransformerLM +train.optimizer = @trax.optimizers.Adam +train.train_steps = 100000 +train.mask_id = 0 + +# Parameters for DotProductCausalAttention: +# ============================================================================== +DotProductCausalAttention.dropout = 0.1 + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = @trax.layers.DotProductCausalAttention +TransformerLM.d_model = 512 +TransformerLM.d_ff = 2048 +TransformerLM.dropout = 0.1 +TransformerLM.max_len = 2048 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = 6 +TransformerLM.vocab_size = 32000 diff --git a/trax/configs/transformer_lm1b_8gb_testing.ginE b/trax/configs/transformer_lm1b_8gb_testing.ginE new file mode 100644 index 000000000..a832553af --- /dev/null +++ b/trax/configs/transformer_lm1b_8gb_testing.ginE @@ -0,0 +1,54 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.trax + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 128 +batch_fn.eval_batch_size = 128 +batch_fn.max_eval_length = 2048 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_languagemodel_lm1b32k' +inputs.input_name = 'targets' + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 0.1 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 8000 + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.lm1b_preprocess +lm1b_preprocess.max_target_length = 512 +lm1b_preprocess.max_eval_target_length = 2048 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 1000 +train.eval_steps = 10 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.TransformerLM +train.optimizer = @trax.optimizers.Adam +train.train_steps = 100000 +train.mask_id = 0 + +# Parameters for DotProductCausalAttention: +# ============================================================================== +DotProductCausalAttention.dropout = 0.1 + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = @trax.layers.DotProductCausalAttention +TransformerLM.d_model = 512 +TransformerLM.d_ff = 2048 +TransformerLM.dropout = 0.1 +TransformerLM.max_len = 2048 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = 6 +TransformerLM.vocab_size = 32000 diff --git a/trax/configs/transformer_lm_wmt_ende_16gb.gin b/trax/configs/transformer_lm_wmt_ende_16gb.gin new file mode 100644 index 000000000..f723333f6 --- /dev/null +++ b/trax/configs/transformer_lm_wmt_ende_16gb.gin @@ -0,0 +1,148 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.models +import trax.optimizers +import trax.trainer_lib + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 64 +batch_fn.eval_batch_size = 64 +batch_fn.bucket_length=64 +batch_fn.max_eval_length = 1024 +batch_fn.buckets_include_inputs_in_length=True + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_translate_ende_wmt32k' + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 1.0 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 10000 + +# Parameters for Adafactor: +# ============================================================================== +Adafactor.beta1 = 0.0 +Adafactor.decay_rate = 0.8 +Adafactor.clipping_threshold = 1.0 +Adafactor.epsilon1 = 1e-30 +Adafactor.epsilon2 = 0.001 +Adafactor.factored = True +Adafactor.multiply_by_parameter_scale = True + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.wmt_concat_preprocess +wmt_concat_preprocess.max_length = 255 +wmt_concat_preprocess.max_eval_length = 511 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 500 +train.eval_steps = 1 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.TransformerLM +train.train_steps = 50000 +train.optimizer = @trax.optimizers.Adafactor +train.has_weights = True +train.mask_id = 0 + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.d_model = 512 +TransformerLM.d_ff = 2048 +TransformerLM.dropout = 0.1 +TransformerLM.max_len = 2048 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = 6 +TransformerLM.vocab_size = 33300 + +# ============================================================================== +# Parameters for the RL hyperparameter tuner; turn on with +# train.lr_schedule=@learning_rate.PolicySchedule and set +# PolicySchedule.policy_dir. +# ============================================================================== + +# Parameters for PolicySchedule: +# ============================================================================== +PolicySchedule.observation_metrics = ( + ("train", "metrics/accuracy"), + ("train", "metrics/loss"), + ("eval", "metrics/accuracy"), + ("eval", "metrics/loss"), +) +PolicySchedule.include_controls_in_observation = False +PolicySchedule.control_configs = ( + ("learning_rate", 1e-3, (1e-9, 1e-2), False), + ("weight_decay_rate", 1e-5, (1e-9, 1e-3), False), + + ("dropout_embedding", 0.1, (0.0, 0.9), True), + + ("dropout_attention_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_initial", 0.1, (0.0, 0.9), True), + + ("dropout_attention_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_middle", 0.1, (0.0, 0.9), True), + + ("dropout_attention_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_final", 0.1, (0.0, 0.9), True), +) +PolicySchedule.observation_range = (0.0, 10.0) +PolicySchedule.action_multipliers = (0.5, 0.8, 0.95, 1.0, 1.05, 1.25, 2.0) +PolicySchedule.policy_and_value_model = @trax.models.TransformerDecoder +PolicySchedule.policy_and_value_two_towers = False + +# Parameters for train: +# ============================================================================== +train.nontrainable_param_map = { + # "dropout_{layer_type}_{block_index}": "dropout_{layer_type}_{block_group}" + + "dropout_attention_0": "dropout_attention_initial", + "dropout_ff_middle_0": "dropout_ff_middle_initial", + "dropout_ff_final_0": "dropout_ff_final_initial", + + "dropout_attention_1": "dropout_attention_middle", + "dropout_ff_middle_1": "dropout_ff_middle_middle", + "dropout_ff_final_1": "dropout_ff_final_middle", + "dropout_attention_2": "dropout_attention_middle", + "dropout_ff_middle_2": "dropout_ff_middle_middle", + "dropout_ff_final_2": "dropout_ff_final_middle", + "dropout_attention_3": "dropout_attention_middle", + "dropout_ff_middle_3": "dropout_ff_middle_middle", + "dropout_ff_final_3": "dropout_ff_final_middle", + "dropout_attention_4": "dropout_attention_middle", + "dropout_ff_middle_4": "dropout_ff_middle_middle", + "dropout_ff_final_4": "dropout_ff_final_middle", + + "dropout_attention_5": "dropout_attention_final", + "dropout_ff_middle_5": "dropout_ff_middle_final", + "dropout_ff_final_5": "dropout_ff_final_final", +} + +# Parameters for TransformerDecoder: +# ============================================================================== +TransformerDecoder.d_model = 64 +TransformerDecoder.d_ff = 128 +TransformerDecoder.dropout = 0.0 +TransformerDecoder.n_heads = 2 +TransformerDecoder.n_layers = 2 diff --git a/trax/configs/transformer_lm_wmt_ende_16gb.ginE b/trax/configs/transformer_lm_wmt_ende_16gb.ginE new file mode 100644 index 000000000..d4ea4390e --- /dev/null +++ b/trax/configs/transformer_lm_wmt_ende_16gb.ginE @@ -0,0 +1,134 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.trax + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 64 +batch_fn.eval_batch_size = 64 +batch_fn.bucket_length=64 +batch_fn.max_eval_length = 1024 +batch_fn.buckets_include_inputs_in_length=True + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_translate_ende_wmt32k' + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 1.0 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 10000 + +# Parameters for Adafactor: +# ============================================================================== +Adafactor.beta1 = 0.0 +Adafactor.decay_rate = 0.8 +Adafactor.clipping_threshold = 1.0 +Adafactor.epsilon1 = 1e-30 +Adafactor.epsilon2 = 0.001 +Adafactor.factored = True +Adafactor.multiply_by_parameter_scale = True + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.wmt_concat_preprocess +wmt_concat_preprocess.max_length = 255 +wmt_concat_preprocess.max_eval_length = 511 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 500 +train.eval_steps = 1 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.TransformerLM +train.train_steps = 50000 +train.optimizer = @trax.optimizers.Adafactor +train.has_weights = True +train.mask_id = 0 + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.d_model = 512 +TransformerLM.d_ff = 2048 +TransformerLM.dropout = 0.1 +TransformerLM.max_len = 2048 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = 6 +TransformerLM.vocab_size = 33300 + +# ============================================================================== +# Parameters for the RL hyperparameter tuner; turn on with +# train.lr_schedule=@learning_rate.PolicySchedule and set +# PolicySchedule.policy_dir. +# ============================================================================== + +# Parameters for PolicySchedule: +# ============================================================================== +PolicySchedule.observation_metrics = ( + ("train", "metrics/accuracy"), + ("train", "metrics/loss"), + ("eval", "metrics/accuracy"), + ("eval", "metrics/loss"), +) +PolicySchedule.include_controls_in_observation = False +PolicySchedule.control_configs = ( + ("learning_rate", 1e-3, (1e-9, 1e-2), False), + ("weight_decay_rate", 1e-5, (1e-9, 1e-3), False), + + ("dropout_embedding", 0.1, (0.0, 0.9), True), + + ("dropout_attention_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_initial", 0.1, (0.0, 0.9), True), + + ("dropout_attention_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_middle", 0.1, (0.0, 0.9), True), + + ("dropout_attention_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_final", 0.1, (0.0, 0.9), True), +) +PolicySchedule.observation_range = (0.0, 10.0) +PolicySchedule.action_multipliers = (0.5, 0.8, 0.95, 1.0, 1.05, 1.25, 2.0) +PolicySchedule.policy_and_value_model = @trax.models.TransformerDecoder +PolicySchedule.policy_and_value_two_towers = False + +# Parameters for train: +# ============================================================================== +train.nontrainable_param_map = { + # "dropout_{layer_type}_{block_index}": "dropout_{layer_type}_{block_group}" + + "dropout_attention_0": "dropout_attention_initial", + "dropout_ff_middle_0": "dropout_ff_middle_initial", + "dropout_ff_final_0": "dropout_ff_final_initial", + + "dropout_attention_1": "dropout_attention_middle", + "dropout_ff_middle_1": "dropout_ff_middle_middle", + "dropout_ff_final_1": "dropout_ff_final_middle", + "dropout_attention_2": "dropout_attention_middle", + "dropout_ff_middle_2": "dropout_ff_middle_middle", + "dropout_ff_final_2": "dropout_ff_final_middle", + "dropout_attention_3": "dropout_attention_middle", + "dropout_ff_middle_3": "dropout_ff_middle_middle", + "dropout_ff_final_3": "dropout_ff_final_middle", + "dropout_attention_4": "dropout_attention_middle", + "dropout_ff_middle_4": "dropout_ff_middle_middle", + "dropout_ff_final_4": "dropout_ff_final_middle", + + "dropout_attention_5": "dropout_attention_final", + "dropout_ff_middle_5": "dropout_ff_middle_final", + "dropout_ff_final_5": "dropout_ff_final_final", +} + +# Parameters for TransformerDecoder: +# ============================================================================== +TransformerDecoder.d_model = 64 +TransformerDecoder.d_ff = 128 +TransformerDecoder.dropout = 0.0 +TransformerDecoder.n_heads = 2 +TransformerDecoder.n_layers = 2 diff --git a/trax/configs/transformer_lm_wmt_ende_8gb.gin b/trax/configs/transformer_lm_wmt_ende_8gb.gin new file mode 100644 index 000000000..dca394e8f --- /dev/null +++ b/trax/configs/transformer_lm_wmt_ende_8gb.gin @@ -0,0 +1,78 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.models +import trax.optimizers +import trax.trainer_lib + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 64 +batch_fn.eval_batch_size = 64 +batch_fn.bucket_length=64 +batch_fn.max_eval_length = 1024 +batch_fn.buckets_include_inputs_in_length=True + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_translate_ende_wmt32k' + +# Parameters for mask: +# ============================================================================== +masked_mean.mask_id = 0 + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 1.0 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 10000 + +# Parameters for Adafactor: +# ============================================================================== +Adafactor.beta1 = 0.0 +Adafactor.decay_rate = 0.8 +Adafactor.clipping_threshold = 1.0 +Adafactor.epsilon1 = 1e-30 +Adafactor.epsilon2 = 0.001 +Adafactor.factored = True +Adafactor.multiply_by_parameter_scale = True + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.wmt_concat_preprocess +wmt_concat_preprocess.max_length = 255 +wmt_concat_preprocess.max_eval_length = 511 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 1000 +train.eval_steps = 10 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.TransformerLM +train.train_steps = 500000 +train.optimizer = @trax.optimizers.Adafactor +train.has_weights = True + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.d_model = 512 +TransformerLM.d_ff = 2048 +TransformerLM.dropout = 0.1 +TransformerLM.max_len = 2048 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = 6 +TransformerLM.vocab_size = 33300 diff --git a/trax/configs/transformer_lm_wmt_ende_8gb.ginE b/trax/configs/transformer_lm_wmt_ende_8gb.ginE new file mode 100644 index 000000000..ed43f7919 --- /dev/null +++ b/trax/configs/transformer_lm_wmt_ende_8gb.ginE @@ -0,0 +1,64 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.trax + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 64 +batch_fn.eval_batch_size = 64 +batch_fn.bucket_length=64 +batch_fn.max_eval_length = 1024 +batch_fn.buckets_include_inputs_in_length=True + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_translate_ende_wmt32k' + +# Parameters for mask: +# ============================================================================== +masked_mean.mask_id = 0 + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 1.0 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 10000 + +# Parameters for Adafactor: +# ============================================================================== +Adafactor.beta1 = 0.0 +Adafactor.decay_rate = 0.8 +Adafactor.clipping_threshold = 1.0 +Adafactor.epsilon1 = 1e-30 +Adafactor.epsilon2 = 0.001 +Adafactor.factored = True +Adafactor.multiply_by_parameter_scale = True + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.wmt_concat_preprocess +wmt_concat_preprocess.max_length = 255 +wmt_concat_preprocess.max_eval_length = 511 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 1000 +train.eval_steps = 10 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.TransformerLM +train.train_steps = 500000 +train.optimizer = @trax.optimizers.Adafactor +train.has_weights = True + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.d_model = 512 +TransformerLM.d_ff = 2048 +TransformerLM.dropout = 0.1 +TransformerLM.max_len = 2048 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = 6 +TransformerLM.vocab_size = 33300 diff --git a/trax/configs/transformer_ptb_16gb.gin b/trax/configs/transformer_ptb_16gb.gin new file mode 100644 index 000000000..bfecdc189 --- /dev/null +++ b/trax/configs/transformer_ptb_16gb.gin @@ -0,0 +1,142 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.learning_rate +import trax.models +import trax.optimizers +import trax.trainer_lib + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 64 +batch_fn.eval_batch_size = 512 +batch_fn.max_eval_length = 2048 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_languagemodel_ptb10k' +inputs.input_name = 'targets' + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 1.0 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 8000 + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.lm1b_preprocess +lm1b_preprocess.max_target_length = 512 +lm1b_preprocess.max_eval_target_length = 2048 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 200 +train.eval_steps = 2 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.TransformerLM +train.optimizer = @trax.optimizers.Adafactor +train.train_steps = 20000 +train.mask_id = 0 + +# Parameters for DotProductCausalAttention: +# ============================================================================== +DotProductCausalAttention.dropout = 0.5 + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = @trax.layers.DotProductCausalAttention +TransformerLM.d_model = 512 +TransformerLM.d_ff = 2048 +TransformerLM.dropout = 0.5 +TransformerLM.max_len = 2048 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = 6 +TransformerLM.vocab_size = 10240 + +# ============================================================================== +# Parameters for the RL hyperparameter tuner; turn on with +# train.lr_schedule=@learning_rate.PolicySchedule and set +# PolicySchedule.policy_dir. +# ============================================================================== + +# Parameters for PolicySchedule: +# ============================================================================== +PolicySchedule.observation_metrics = ( + ("train", "metrics/accuracy"), + ("train", "metrics/loss"), + ("eval", "metrics/accuracy"), + ("eval", "metrics/loss"), +) +PolicySchedule.include_controls_in_observation = False +PolicySchedule.control_configs = ( + ("learning_rate", 1e-3, (1e-9, 1e-2), False), + ("weight_decay_rate", 1e-5, (1e-9, 1e-3), False), + + ("dropout_embedding", 0.1, (0.0, 0.9), True), + + ("dropout_attention_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_initial", 0.1, (0.0, 0.9), True), + + ("dropout_attention_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_middle", 0.1, (0.0, 0.9), True), + + ("dropout_attention_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_final", 0.1, (0.0, 0.9), True), +) +PolicySchedule.observation_range = (0.0, 10.0) +PolicySchedule.action_multipliers = (0.5, 0.8, 0.95, 1.0, 1.05, 1.25, 2.0) +PolicySchedule.policy_and_value_model = @trax.models.TransformerDecoder +PolicySchedule.policy_and_value_two_towers = False + +# Parameters for train: +# ============================================================================== +train.nontrainable_param_map = { + # "dropout_{layer_type}_{block_index}": "dropout_{layer_type}_{block_group}" + + "dropout_attention_0": "dropout_attention_initial", + "dropout_ff_middle_0": "dropout_ff_middle_initial", + "dropout_ff_final_0": "dropout_ff_final_initial", + + "dropout_attention_1": "dropout_attention_middle", + "dropout_ff_middle_1": "dropout_ff_middle_middle", + "dropout_ff_final_1": "dropout_ff_final_middle", + "dropout_attention_2": "dropout_attention_middle", + "dropout_ff_middle_2": "dropout_ff_middle_middle", + "dropout_ff_final_2": "dropout_ff_final_middle", + "dropout_attention_3": "dropout_attention_middle", + "dropout_ff_middle_3": "dropout_ff_middle_middle", + "dropout_ff_final_3": "dropout_ff_final_middle", + "dropout_attention_4": "dropout_attention_middle", + "dropout_ff_middle_4": "dropout_ff_middle_middle", + "dropout_ff_final_4": "dropout_ff_final_middle", + + "dropout_attention_5": "dropout_attention_final", + "dropout_ff_middle_5": "dropout_ff_middle_final", + "dropout_ff_final_5": "dropout_ff_final_final", +} + +# Parameters for TransformerDecoder: +# ============================================================================== +TransformerDecoder.d_model = 64 +TransformerDecoder.d_ff = 128 +TransformerDecoder.dropout = 0.0 +TransformerDecoder.n_heads = 2 +TransformerDecoder.n_layers = 2 diff --git a/trax/configs/transformer_ptb_16gb.ginE b/trax/configs/transformer_ptb_16gb.ginE new file mode 100644 index 000000000..9d1b796ff --- /dev/null +++ b/trax/configs/transformer_ptb_16gb.ginE @@ -0,0 +1,128 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.learning_rate +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.trax + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 64 +batch_fn.eval_batch_size = 512 +batch_fn.max_eval_length = 2048 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_languagemodel_ptb10k' +inputs.input_name = 'targets' + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 1.0 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 8000 + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.lm1b_preprocess +lm1b_preprocess.max_target_length = 512 +lm1b_preprocess.max_eval_target_length = 2048 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 200 +train.eval_steps = 2 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.TransformerLM +train.optimizer = @trax.optimizers.Adafactor +train.train_steps = 20000 +train.mask_id = 0 + +# Parameters for DotProductCausalAttention: +# ============================================================================== +DotProductCausalAttention.dropout = 0.5 + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = @trax.layers.DotProductCausalAttention +TransformerLM.d_model = 512 +TransformerLM.d_ff = 2048 +TransformerLM.dropout = 0.5 +TransformerLM.max_len = 2048 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = 6 +TransformerLM.vocab_size = 10240 + +# ============================================================================== +# Parameters for the RL hyperparameter tuner; turn on with +# train.lr_schedule=@learning_rate.PolicySchedule and set +# PolicySchedule.policy_dir. +# ============================================================================== + +# Parameters for PolicySchedule: +# ============================================================================== +PolicySchedule.observation_metrics = ( + ("train", "metrics/accuracy"), + ("train", "metrics/loss"), + ("eval", "metrics/accuracy"), + ("eval", "metrics/loss"), +) +PolicySchedule.include_controls_in_observation = False +PolicySchedule.control_configs = ( + ("learning_rate", 1e-3, (1e-9, 1e-2), False), + ("weight_decay_rate", 1e-5, (1e-9, 1e-3), False), + + ("dropout_embedding", 0.1, (0.0, 0.9), True), + + ("dropout_attention_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_initial", 0.1, (0.0, 0.9), True), + + ("dropout_attention_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_middle", 0.1, (0.0, 0.9), True), + + ("dropout_attention_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_final", 0.1, (0.0, 0.9), True), +) +PolicySchedule.observation_range = (0.0, 10.0) +PolicySchedule.action_multipliers = (0.5, 0.8, 0.95, 1.0, 1.05, 1.25, 2.0) +PolicySchedule.policy_and_value_model = @trax.models.TransformerDecoder +PolicySchedule.policy_and_value_two_towers = False + +# Parameters for train: +# ============================================================================== +train.nontrainable_param_map = { + # "dropout_{layer_type}_{block_index}": "dropout_{layer_type}_{block_group}" + + "dropout_attention_0": "dropout_attention_initial", + "dropout_ff_middle_0": "dropout_ff_middle_initial", + "dropout_ff_final_0": "dropout_ff_final_initial", + + "dropout_attention_1": "dropout_attention_middle", + "dropout_ff_middle_1": "dropout_ff_middle_middle", + "dropout_ff_final_1": "dropout_ff_final_middle", + "dropout_attention_2": "dropout_attention_middle", + "dropout_ff_middle_2": "dropout_ff_middle_middle", + "dropout_ff_final_2": "dropout_ff_final_middle", + "dropout_attention_3": "dropout_attention_middle", + "dropout_ff_middle_3": "dropout_ff_middle_middle", + "dropout_ff_final_3": "dropout_ff_final_middle", + "dropout_attention_4": "dropout_attention_middle", + "dropout_ff_middle_4": "dropout_ff_middle_middle", + "dropout_ff_final_4": "dropout_ff_final_middle", + + "dropout_attention_5": "dropout_attention_final", + "dropout_ff_middle_5": "dropout_ff_middle_final", + "dropout_ff_final_5": "dropout_ff_final_final", +} + +# Parameters for TransformerDecoder: +# ============================================================================== +TransformerDecoder.d_model = 64 +TransformerDecoder.d_ff = 128 +TransformerDecoder.dropout = 0.0 +TransformerDecoder.n_heads = 2 +TransformerDecoder.n_layers = 2 diff --git a/trax/configs/transformer_wmt_ende_16gb_adafactor_testing.gin b/trax/configs/transformer_wmt_ende_16gb_adafactor_testing.gin new file mode 100644 index 000000000..c61b11014 --- /dev/null +++ b/trax/configs/transformer_wmt_ende_16gb_adafactor_testing.gin @@ -0,0 +1,73 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.models +import trax.optimizers +import trax.trainer_lib + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 128 +batch_fn.eval_batch_size = 64 +batch_fn.max_eval_length = 1024 +batch_fn.buckets_include_inputs_in_length=True + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_translate_ende_wmt32k' + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 1.0 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 20000 + +# Parameters for Adafactor: +# ============================================================================== +Adafactor.beta1 = 0.0 +Adafactor.decay_rate = 0.8 +Adafactor.clipping_threshold = 1.0 +Adafactor.epsilon1 = 1e-30 +Adafactor.epsilon2 = 0.001 +Adafactor.factored = True +Adafactor.multiply_by_parameter_scale = True + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.wmt_preprocess +wmt_preprocess.max_length = 512 +wmt_preprocess.max_eval_length = 1024 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 1000 +train.eval_steps = 10 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.Transformer +train.train_steps = 100000 +train.optimizer = @trax.optimizers.Adafactor +train.mask_id = 0 + +# Parameters for Transformer: +# ============================================================================== +Transformer.d_model = 512 +Transformer.d_ff = 2048 +Transformer.dropout = 0.1 +Transformer.max_len = 2048 +Transformer.mode = 'train' +Transformer.n_heads = 8 +Transformer.n_layers = 6 +Transformer.input_vocab_size = 33300 diff --git a/trax/configs/transformer_wmt_ende_16gb_adafactor_testing.ginE b/trax/configs/transformer_wmt_ende_16gb_adafactor_testing.ginE new file mode 100644 index 000000000..40451973d --- /dev/null +++ b/trax/configs/transformer_wmt_ende_16gb_adafactor_testing.ginE @@ -0,0 +1,59 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.trax + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 128 +batch_fn.eval_batch_size = 64 +batch_fn.max_eval_length = 1024 +batch_fn.buckets_include_inputs_in_length=True + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_translate_ende_wmt32k' + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 1.0 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 20000 + +# Parameters for Adafactor: +# ============================================================================== +Adafactor.beta1 = 0.0 +Adafactor.decay_rate = 0.8 +Adafactor.clipping_threshold = 1.0 +Adafactor.epsilon1 = 1e-30 +Adafactor.epsilon2 = 0.001 +Adafactor.factored = True +Adafactor.multiply_by_parameter_scale = True + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.wmt_preprocess +wmt_preprocess.max_length = 512 +wmt_preprocess.max_eval_length = 1024 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 1000 +train.eval_steps = 10 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.Transformer +train.train_steps = 100000 +train.optimizer = @trax.optimizers.Adafactor +train.mask_id = 0 + +# Parameters for Transformer: +# ============================================================================== +Transformer.d_model = 512 +Transformer.d_ff = 2048 +Transformer.dropout = 0.1 +Transformer.max_len = 2048 +Transformer.mode = 'train' +Transformer.n_heads = 8 +Transformer.n_layers = 6 +Transformer.input_vocab_size = 33300 diff --git a/trax/configs/transformer_wmt_ende_8gb_adafactor.gin b/trax/configs/transformer_wmt_ende_8gb_adafactor.gin new file mode 100644 index 000000000..42a35eb66 --- /dev/null +++ b/trax/configs/transformer_wmt_ende_8gb_adafactor.gin @@ -0,0 +1,73 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.models +import trax.optimizers +import trax.trainer_lib + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 64 +batch_fn.eval_batch_size = 64 +batch_fn.max_eval_length = 1024 +batch_fn.buckets_include_inputs_in_length=True + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_translate_ende_wmt32k' + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 1.0 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 10000 + +# Parameters for Adafactor: +# ============================================================================== +Adafactor.beta1 = 0.0 +Adafactor.decay_rate = 0.8 +Adafactor.clipping_threshold = 1.0 +Adafactor.epsilon1 = 1e-30 +Adafactor.epsilon2 = 0.001 +Adafactor.factored = True +Adafactor.multiply_by_parameter_scale = True + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.wmt_preprocess +wmt_preprocess.max_length = 512 +wmt_preprocess.max_eval_length = 1024 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 1000 +train.eval_steps = 10 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.Transformer +train.train_steps = 500000 +train.optimizer = @trax.optimizers.Adafactor +train.mask_id = 0 + +# Parameters for Transformer: +# ============================================================================== +Transformer.d_model = 512 +Transformer.d_ff = 2048 +Transformer.dropout = 0.1 +Transformer.max_len = 2048 +Transformer.mode = 'train' +Transformer.n_heads = 8 +Transformer.n_layers = 6 +Transformer.input_vocab_size = 33300 diff --git a/trax/configs/transformer_wmt_ende_8gb_adafactor.ginE b/trax/configs/transformer_wmt_ende_8gb_adafactor.ginE new file mode 100644 index 000000000..db95a6236 --- /dev/null +++ b/trax/configs/transformer_wmt_ende_8gb_adafactor.ginE @@ -0,0 +1,59 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.trax + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 64 +batch_fn.eval_batch_size = 64 +batch_fn.max_eval_length = 1024 +batch_fn.buckets_include_inputs_in_length=True + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_translate_ende_wmt32k' + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 1.0 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 10000 + +# Parameters for Adafactor: +# ============================================================================== +Adafactor.beta1 = 0.0 +Adafactor.decay_rate = 0.8 +Adafactor.clipping_threshold = 1.0 +Adafactor.epsilon1 = 1e-30 +Adafactor.epsilon2 = 0.001 +Adafactor.factored = True +Adafactor.multiply_by_parameter_scale = True + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.wmt_preprocess +wmt_preprocess.max_length = 512 +wmt_preprocess.max_eval_length = 1024 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 1000 +train.eval_steps = 10 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.Transformer +train.train_steps = 500000 +train.optimizer = @trax.optimizers.Adafactor +train.mask_id = 0 + +# Parameters for Transformer: +# ============================================================================== +Transformer.d_model = 512 +Transformer.d_ff = 2048 +Transformer.dropout = 0.1 +Transformer.max_len = 2048 +Transformer.mode = 'train' +Transformer.n_heads = 8 +Transformer.n_layers = 6 +Transformer.input_vocab_size = 33300 diff --git a/trax/configs/transformer_wmt_ende_8gb_adam.gin b/trax/configs/transformer_wmt_ende_8gb_adam.gin new file mode 100644 index 000000000..7e10ccff1 --- /dev/null +++ b/trax/configs/transformer_wmt_ende_8gb_adam.gin @@ -0,0 +1,70 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.models +import trax.optimizers +import trax.trainer_lib + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 64 +batch_fn.eval_batch_size = 64 +batch_fn.max_eval_length = 1024 +batch_fn.buckets_include_inputs_in_length=True + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_translate_ende_wmt32k' + +# Parameters for MultifactorSchedule: +# ============================================================================== +# 0.044 ~= 512^-0.5 = d_model^-0.5 +MultifactorSchedule.constant = 0.044 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 8000 + +# Parameters for Adam: +# ============================================================================== +Adam.b1 = 0.9 +Adam.b2 = 0.98 +Adam.eps = 1e-9 + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.wmt_preprocess +wmt_preprocess.max_length = 512 +wmt_preprocess.max_eval_length = 1024 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 1000 +train.eval_steps = 10 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.Transformer +train.train_steps = 500000 +train.optimizer = @trax.optimizers.Adam +train.mask_id = 0 + +# Parameters for Transformer: +# ============================================================================== +Transformer.d_model= 512 +Transformer.d_ff = 2048 +Transformer.dropout = 0.1 +Transformer.max_len = 2048 +Transformer.mode = 'train' +Transformer.n_heads = 8 +Transformer.n_layers = 6 +Transformer.input_vocab_size = 33300 diff --git a/trax/configs/transformer_wmt_ende_8gb_adam.ginE b/trax/configs/transformer_wmt_ende_8gb_adam.ginE new file mode 100644 index 000000000..2638003ac --- /dev/null +++ b/trax/configs/transformer_wmt_ende_8gb_adam.ginE @@ -0,0 +1,56 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.trax + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 64 +batch_fn.eval_batch_size = 64 +batch_fn.max_eval_length = 1024 +batch_fn.buckets_include_inputs_in_length=True + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_translate_ende_wmt32k' + +# Parameters for MultifactorSchedule: +# ============================================================================== +# 0.044 ~= 512^-0.5 = d_model^-0.5 +MultifactorSchedule.constant = 0.044 +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +MultifactorSchedule.warmup_steps = 8000 + +# Parameters for Adam: +# ============================================================================== +Adam.b1 = 0.9 +Adam.b2 = 0.98 +Adam.eps = 1e-9 + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.wmt_preprocess +wmt_preprocess.max_length = 512 +wmt_preprocess.max_eval_length = 1024 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 1000 +train.eval_steps = 10 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.Transformer +train.train_steps = 500000 +train.optimizer = @trax.optimizers.Adam +train.mask_id = 0 + +# Parameters for Transformer: +# ============================================================================== +Transformer.d_model= 512 +Transformer.d_ff = 2048 +Transformer.dropout = 0.1 +Transformer.max_len = 2048 +Transformer.mode = 'train' +Transformer.n_heads = 8 +Transformer.n_layers = 6 +Transformer.input_vocab_size = 33300 diff --git a/trax/configs/transformer_wmt_ende_8gb_sm3.gin b/trax/configs/transformer_wmt_ende_8gb_sm3.gin new file mode 100644 index 000000000..29cb232a1 --- /dev/null +++ b/trax/configs/transformer_wmt_ende_8gb_sm3.gin @@ -0,0 +1,67 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.models +import trax.optimizers +import trax.trainer_lib + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 64 +batch_fn.eval_batch_size = 64 +batch_fn.max_eval_length = 1024 +batch_fn.buckets_include_inputs_in_length=True + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_translate_ende_wmt32k' + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 0.1 +MultifactorSchedule.factors = 'constant * linear_warmup' +MultifactorSchedule.warmup_steps = 8000 + +# Parameters for SM3: +# ============================================================================== +SM3.momentum = 0.9 + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.wmt_preprocess +wmt_preprocess.max_length = 512 +wmt_preprocess.max_eval_length = 1024 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 1000 +train.eval_steps = 10 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.Transformer +train.train_steps = 500000 +train.optimizer = @trax.optimizers.SM3 +train.mask_id = 0 + +# Parameters for Transformer: +# ============================================================================== +Transformer.d_model= 512 +Transformer.d_ff = 2048 +Transformer.dropout = 0.1 +Transformer.max_len = 2048 +Transformer.mode = 'train' +Transformer.n_heads = 8 +Transformer.n_layers = 6 +Transformer.input_vocab_size = 33300 diff --git a/trax/configs/transformer_wmt_ende_8gb_sm3.ginE b/trax/configs/transformer_wmt_ende_8gb_sm3.ginE new file mode 100644 index 000000000..04a4356b4 --- /dev/null +++ b/trax/configs/transformer_wmt_ende_8gb_sm3.ginE @@ -0,0 +1,53 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.trax + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 64 +batch_fn.eval_batch_size = 64 +batch_fn.max_eval_length = 1024 +batch_fn.buckets_include_inputs_in_length=True + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_translate_ende_wmt32k' + +# Parameters for MultifactorSchedule: +# ============================================================================== +MultifactorSchedule.constant = 0.1 +MultifactorSchedule.factors = 'constant * linear_warmup' +MultifactorSchedule.warmup_steps = 8000 + +# Parameters for SM3: +# ============================================================================== +SM3.momentum = 0.9 + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.wmt_preprocess +wmt_preprocess.max_length = 512 +wmt_preprocess.max_eval_length = 1024 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 1000 +train.eval_steps = 10 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.Transformer +train.train_steps = 500000 +train.optimizer = @trax.optimizers.SM3 +train.mask_id = 0 + +# Parameters for Transformer: +# ============================================================================== +Transformer.d_model= 512 +Transformer.d_ff = 2048 +Transformer.dropout = 0.1 +Transformer.max_len = 2048 +Transformer.mode = 'train' +Transformer.n_heads = 8 +Transformer.n_layers = 6 +Transformer.input_vocab_size = 33300 diff --git a/trax/configs/wide_resnet_cifar10_8gb.gin b/trax/configs/wide_resnet_cifar10_8gb.gin new file mode 100644 index 000000000..99c3cb76e --- /dev/null +++ b/trax/configs/wide_resnet_cifar10_8gb.gin @@ -0,0 +1,95 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.learning_rate +import trax.models +import trax.optimizers +import trax.trainer_lib + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 256 +batch_fn.bucket_length = 32 +batch_fn.buckets = None +batch_fn.eval_batch_size = 512 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 'cifar10' + +# Parameters for MultifactorSchedule: +# ============================================================================== +EvalAdjustingSchedule.constant = 0.5 +MultifactorSchedule.factors = 'constant * linear_warmup' +MultifactorSchedule.warmup_steps = 400 + +# Parameters for Momentum: +# ============================================================================== +Momentum.mass = 0.9 +Momentum.weight_decay_rate = 5e-4 + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.cifar10_augmentation_preprocess + +# Parameters for WideResnet: +# ============================================================================== +WideResnet.widen_factor = 10 +WideResnet.n_blocks = 4 +WideResnet.n_output_classes = 10 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 100 +train.eval_steps = 10 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.WideResnet +train.optimizer = @trax.optimizers.Momentum +train.train_steps = 10000 +train.lr_schedule = @learning_rate.EvalAdjustingSchedule + +# ============================================================================== +# Parameters for the RL hyperparameter tuner; turn on with +# train.lr_schedule=@learning_rate.PolicySchedule and set +# PolicySchedule.policy_dir. +# ============================================================================== + +# Parameters for PolicySchedule: +# ============================================================================== +PolicySchedule.observation_metrics = ( + ("train", "metrics/accuracy"), + ("train", "metrics/loss"), + ("eval", "metrics/accuracy"), + ("eval", "metrics/loss"), +) +PolicySchedule.include_controls_in_observation = False +PolicySchedule.control_configs = ( + ("learning_rate", 0.1, (1e-9, 10.0), False), + ("weight_decay_rate", 1e-5, (1e-9, 0.1), False), + ("mass", 0.9, (0.0, 0.99), True), +) +PolicySchedule.observation_range = (0.0, 10.0) +PolicySchedule.action_multipliers = (0.5, 0.8, 0.95, 1.0, 1.05, 1.25, 2.0) +PolicySchedule.policy_and_value_model = @trax.models.TransformerDecoder +PolicySchedule.policy_and_value_two_towers = False + +# Parameters for TransformerDecoder: +# ============================================================================== +TransformerDecoder.d_model = 64 +TransformerDecoder.d_ff = 128 +TransformerDecoder.dropout = 0.0 +TransformerDecoder.n_heads = 2 +TransformerDecoder.n_layers = 2 diff --git a/trax/configs/wide_resnet_cifar10_8gb.ginE b/trax/configs/wide_resnet_cifar10_8gb.ginE new file mode 100644 index 000000000..7edefa0e4 --- /dev/null +++ b/trax/configs/wide_resnet_cifar10_8gb.ginE @@ -0,0 +1,81 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.learning_rate +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.trax + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 256 +batch_fn.bucket_length = 32 +batch_fn.buckets = None +batch_fn.eval_batch_size = 512 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 'cifar10' + +# Parameters for MultifactorSchedule: +# ============================================================================== +EvalAdjustingSchedule.constant = 0.5 +MultifactorSchedule.factors = 'constant * linear_warmup' +MultifactorSchedule.warmup_steps = 400 + +# Parameters for Momentum: +# ============================================================================== +Momentum.mass = 0.9 +Momentum.weight_decay_rate = 5e-4 + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.cifar10_augmentation_preprocess + +# Parameters for WideResnet: +# ============================================================================== +WideResnet.widen_factor = 10 +WideResnet.n_blocks = 4 +WideResnet.n_output_classes = 10 + +# Parameters for train: +# ============================================================================== +train.eval_frequency = 100 +train.eval_steps = 10 +train.inputs = @trax.inputs.inputs +train.model = @trax.models.WideResnet +train.optimizer = @trax.optimizers.Momentum +train.train_steps = 10000 +train.lr_schedule = @learning_rate.EvalAdjustingSchedule + +# ============================================================================== +# Parameters for the RL hyperparameter tuner; turn on with +# train.lr_schedule=@learning_rate.PolicySchedule and set +# PolicySchedule.policy_dir. +# ============================================================================== + +# Parameters for PolicySchedule: +# ============================================================================== +PolicySchedule.observation_metrics = ( + ("train", "metrics/accuracy"), + ("train", "metrics/loss"), + ("eval", "metrics/accuracy"), + ("eval", "metrics/loss"), +) +PolicySchedule.include_controls_in_observation = False +PolicySchedule.control_configs = ( + ("learning_rate", 0.1, (1e-9, 10.0), False), + ("weight_decay_rate", 1e-5, (1e-9, 0.1), False), + ("mass", 0.9, (0.0, 0.99), True), +) +PolicySchedule.observation_range = (0.0, 10.0) +PolicySchedule.action_multipliers = (0.5, 0.8, 0.95, 1.0, 1.05, 1.25, 2.0) +PolicySchedule.policy_and_value_model = @trax.models.TransformerDecoder +PolicySchedule.policy_and_value_two_towers = False + +# Parameters for TransformerDecoder: +# ============================================================================== +TransformerDecoder.d_model = 64 +TransformerDecoder.d_ff = 128 +TransformerDecoder.dropout = 0.0 +TransformerDecoder.n_heads = 2 +TransformerDecoder.n_layers = 2 diff --git a/trax/history.py b/trax/history.py new file mode 100644 index 000000000..c6495daf4 --- /dev/null +++ b/trax/history.py @@ -0,0 +1,78 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""trax history.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from absl import logging + + +class History(object): + """History of metrics. + + History contains the metrics recorded during training and evaluation. + Save data with history.append and get a sequence of data by calling + history.get. + + For example: + history.append("train", "metrics/accuracy", 1, 0.04) + history.append("train", "metrics/accuracy", 1000, 0.31) + history.get("train", "metrics/accuracy") + # returns [(1, 0.04), (1000, 0.31)] + """ + + def __init__(self): + # Structure is + # values = { + # "mode1": { + # "metric1": [val1, val2], + # ... + # }, + # "mode2": ... + # } + self._values = {} + + def append(self, mode, metric, step, value): + """Append (step, value) pair to history for the given mode and metric.""" + if mode not in self._values: + self._values[mode] = collections.defaultdict(list) + self._values[mode][metric].append((step, value)) + + def get(self, mode, metric): + """Get the history for the given metric and mode.""" + if mode not in self._values: + logging.info("Metric %s not found for mode %s", metric, mode) + return [] + return list(self._values[mode][metric]) + + @property + def modes(self): + """Current tracked modes.""" + return sorted(list(self._values.keys())) + + def metrics_for_mode(self, mode): + """Metrics available for a given mode.""" + if mode not in self._values: + logging.info("Mode %s not found", mode) + return [] + return sorted(list(self._values[mode].keys())) + + def __str__(self): + return str(self._values) diff --git a/trax/inputs.py b/trax/inputs.py new file mode 100644 index 000000000..7a441c45b --- /dev/null +++ b/trax/inputs.py @@ -0,0 +1,648 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""trax input pipeline.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import os +import random + +import gin +import numpy as onp + +from tensor2tensor import problems_colab as t2t_problems +import tensorflow as tf +import tensorflow_datasets as tfds +from trax import backend +from trax.backend import numpy as np + +# Inputs is the trax tuple defining the input streams and shapes. +# * train_stream: training data that will be used for training +# may include all the augmentation or selection the training wants +# the shape of examples is [batch_fn.batch_size, ...] +# * train_eval_stream: training data used for evaluation +# examples from training data but usually without augmentation +# the shape of examples is [batch_fn.eval_batch_size, ...] +# * eval_stream: evaluation data stream +# examples from evaluation data, usually without augmentation +# the shape of examples is [batch_fn.eval_batch_size, ...] +# * input_shape: the shape of inputs +# the [...] above, without batch size +# * input_dtype: the data type of inputs +# * target_shape: the shape of targets +# the [...] above, without batch size +# * target_dtype: the data type of targets + +Inputs = collections.namedtuple( + '_Inputs', + ['train_stream', 'train_eval_stream', 'eval_stream', + 'input_shape', 'input_dtype', 'target_shape', 'target_dtype'] +) + +# How many examples from the stream to skip at random during training. +# For now, we skip at most 100K examples for efficiency. +# TODO(lukaszkaiser): can we improve efficiency, should that be changed? +_MAX_SKIP_EXAMPLES = 1e5 + + +def download_and_prepare(dataset_name, data_dir): + """Downloads and prepares T2T or TFDS dataset. + + Args: + dataset_name: tfds dataset or t2t problem name prefixed by "t2t_". + data_dir: location of existing dataset or None. + + Returns: + data_dir: path string of downloaded data. + """ + if not data_dir: + data_dir = os.path.expanduser('~/tensorflow_datasets/') + dl_dir = os.path.join(data_dir, 'download') + tf.logging.info( + ('No dataset directory provided. ' + 'Downloading and generating dataset for %s inside data directory %s ' + 'For large datasets it is better to prepare datasets manually!') + % (dataset_name, data_dir)) + if dataset_name.startswith('t2t_'): + # Download and run dataset generator for T2T problem. + data_dir = os.path.join(data_dir, dataset_name) + tf.gfile.MakeDirs(data_dir) + tf.gfile.MakeDirs(dl_dir) + t2t_problems.problem( + dataset_name[len('t2t_'):]).generate_data(data_dir, dl_dir) + else: + # Download and prepare TFDS dataset. + tfds_builder = tfds.builder(dataset_name) + tfds_builder.download_and_prepare(download_dir=dl_dir) + else: + data_dir = os.path.expanduser(data_dir) + return data_dir + + +@gin.configurable(blacklist=['n_devices']) +def inputs(n_devices, dataset_name, data_dir=None, input_name=None, + n_chunks=0): + """Make Inputs for built-in datasets. + + Args: + n_devices: how many devices to build the inputs for. + dataset_name: a TFDS or T2T dataset name. If it's a T2T dataset name, prefix + with "t2t_". + data_dir: data directory. + input_name: optional, name of the inputs from the dictionary. + n_chunks: optional, into how many pieces should we chunk (large inputs). + + Returns: + trax.inputs.Inputs + """ + data_dir = download_and_prepare(dataset_name, data_dir) + + (train_batches, train_eval_batches, eval_batches, + input_name, input_shape, input_dtype, + target_shape, target_dtype) = _train_and_eval_batches( + dataset_name, data_dir, input_name, n_devices) + + if isinstance(input_dtype, tf.DType): + input_dtype = input_dtype.as_numpy_dtype + if isinstance(target_dtype, tf.DType): + target_dtype = target_dtype.as_numpy_dtype + + if input_dtype == np.uint8: # TPUs don't like uint8s, we cast to ints. + input_dtype = np.int32 + if target_dtype == np.uint8: + target_dtype = np.int32 + + def numpy_stream(dataset): + return dataset_to_stream(dataset, input_name, n_chunks=n_chunks) + + if n_chunks > 0: + length = input_shape[0] + input_shape = tuple( + [tuple([length // n_chunks] + list(input_shape)[1:])] * n_chunks) + input_dtype = tuple([input_dtype] * n_chunks) + target_shape = tuple( + [tuple([length // n_chunks] + list(target_shape)[1:])] * n_chunks) + target_dtype = tuple([target_dtype] * n_chunks) + + return Inputs(train_stream=lambda: numpy_stream(train_batches), + train_eval_stream=lambda: numpy_stream(train_eval_batches), + eval_stream=lambda: numpy_stream(eval_batches), + input_shape=input_shape, input_dtype=input_dtype, + target_shape=target_shape, target_dtype=target_dtype) + + +@gin.configurable(blacklist=['n_devices']) +def random_inputs( + n_devices, + input_shape=gin.REQUIRED, input_dtype=np.int32, input_range=(0, 255), + output_shape=gin.REQUIRED, output_dtype=np.int32, output_range=(0, 9)): + """Make random Inputs for debugging. + + Args: + n_devices: how many devices to build the inputs for. + input_shape: the shape of inputs (including batch dimension). + input_dtype: the type of the inputs (int32 by default). + input_range: the range of inputs (defaults to (0, 255)). + output_shape: the shape of outputs (including batch dimension). + output_dtype: the type of the outputs (int32 by default). + output_range: the range of outputs (defaults to (0, 9)). + + Returns: + trax.inputs.Inputs + """ + if input_shape[0] % n_devices != 0: + tf.logging.fatal( + 'n_devices[%d] should divide the first dimension of input_shape[%s]', + n_devices, input_shape) + if output_shape[0] % n_devices != 0: + tf.logging.fatal( + 'n_devices[%d] should divide the first dimension of output_shape[%s]', + n_devices, output_shape) + + def random_minibatches(): + """Generate a stream of random mini-batches.""" + if input_dtype in [np.float16, np.float32, np.float64]: + rand = onp.random.uniform + else: + rand = onp.random.random_integers + while True: + inp = rand(input_range[0], input_range[1], input_shape) + inp = inp.astype(input_dtype) + out = rand(output_range[0], output_range[1], output_shape) + out = out.astype(output_dtype) + yield inp, out + + input_shape_without_batch = list(input_shape)[1:] + output_shape_without_batch = list(output_shape)[1:] + return Inputs(train_stream=random_minibatches, + train_eval_stream=random_minibatches, + eval_stream=random_minibatches, + input_shape=input_shape_without_batch, + input_dtype=input_dtype, + target_shape=output_shape_without_batch, + target_dtype=output_dtype) + + +@gin.configurable(blacklist=['n_devices']) +def sequence_copy_inputs( + n_devices, vocab_size=gin.REQUIRED, batch_size=gin.REQUIRED, + train_lengths=gin.REQUIRED, eval_lengths=gin.REQUIRED, reverse=False): + """Inputs for the sequence copy problem: 0w0w for w in [1..vocab_size-1]*. + + Args: + n_devices: how many devices to build the inputs for. + vocab_size: how many symbols to use. + batch_size: how large are the batches. + train_lengths: lengths of w for training. + eval_lengths: lengths of w for eval. + reverse: bool (optional, false by default): reverse the second sequence. + + Returns: + trax.inputs.Inputs + """ + assert batch_size % n_devices == 0 + def random_minibatches(length_list): + """Generate a stream of random mini-batches.""" + while True: + length = random.choice(length_list) + assert length % 2 == 0 + w_length = (length // 2) - 1 + w = onp.random.randint(low=1, high=vocab_size-1, + size=(batch_size, w_length)) + zero = onp.zeros([batch_size, 1], onp.int32) + loss_weights = onp.concatenate([onp.zeros((batch_size, w_length+2)), + onp.ones((batch_size, w_length))], axis=1) + if reverse: + x = onp.concatenate([zero, w, zero, np.flip(w, axis=1)], axis=1) + else: + x = onp.concatenate([zero, w, zero, w], axis=1) + yield (x, x, loss_weights) # Here inputs and targets are the same. + + # If there's only one length, make the shape known. + example_length = None + if (len(train_lengths) == 1 and len(eval_lengths) == 1 and + train_lengths[0] == eval_lengths[0]): + example_length = train_lengths[0] + + return Inputs( + train_stream=lambda: random_minibatches(train_lengths), + train_eval_stream=lambda: random_minibatches(train_lengths), + eval_stream=lambda: random_minibatches(eval_lengths), + input_shape=(example_length,), + input_dtype=onp.int32, + target_shape=(example_length,), + target_dtype=onp.int32) + + +def dataset_to_stream(dataset, input_name, n_chunks=0): + """Takes a tf.Dataset and creates a numpy stream of ready batches.""" + for example in backend.dataset_as_numpy(dataset): + features = example[0] + inp, out = features[input_name], example[1] + mask = features['mask'] if 'mask' in features else None + # All input-pipeline processing should be on CPU. + with tf.device('cpu:0'): + # Some accelerators don't handle uint8 well, cast to int. + if isinstance(inp, np.uint8): + inp = inp.astype(np.int32) + if isinstance(out, np.uint8): + out = out.astype(np.int32) + if len(out.shape) > 1 and out.shape[-1] == 1: + out = np.squeeze(out, axis=-1) + if n_chunks > 0: + inp = tuple(np.split(inp, n_chunks, axis=1)) + out = tuple(np.split(out, n_chunks, axis=1)) + yield (inp, out) if mask is None else (inp, out, mask) + + +@gin.configurable(whitelist=['train_shuffle_files', 'eval_shuffle_files', + 'eval_holdout_size']) +def train_and_eval_dataset(dataset_name, data_dir, eval_holdout_size=0, + train_shuffle_files=True, eval_shuffle_files=False): + """Return train and evaluation datasets, feature info and supervised keys. + + Args: + dataset_name: a string, the name of the dataset; if it starts with "t2t_" + then we'll search T2T Problem registry for it, otherwise we assume it + is a dataset from TFDS and load it from there. + data_dir: directory where the data is located. + eval_holdout_size: float from 0 to <1; if >0 use this much of training data + for evaluation (instead of looking for a pre-specified VALIDATION split). + train_shuffle_files: Boolean determining whether or not to shuffle the train + files at startup. Set to False if you want data determinism. + eval_shuffle_files: Boolean determining whether or not to shuffle the test + files at startup. Set to False if you want data determinism. + + Returns: + a 4-tuple consisting of: + * the train tf.Dataset + * the eval tf.Dataset + * information about features: a python dictionary with feature names + as keys and an object as value that provides .shape and .n_classes. + * supervised_keys: information what's the input and what's the target, + ie., a pair of lists with input and target feature names. + """ + if dataset_name.startswith('t2t_'): + return _train_and_eval_dataset_v1(dataset_name[4:], data_dir) + dataset_builder = tfds.builder(dataset_name, data_dir=data_dir) + info = dataset_builder.info + splits = dataset_builder.info.splits + if tfds.Split.TRAIN not in splits: + raise ValueError('To train we require a train split in the dataset.') + train_split = tfds.Split.TRAIN + if eval_holdout_size > 0: + holdout_percentage = int(eval_holdout_size * 100.0) + train_percentage = 100 - holdout_percentage + train_split = tfds.Split.TRAIN.subsplit(tfds.percent[:train_percentage]) + eval_split = tfds.Split.TRAIN.subsplit(tfds.percent[train_percentage:]) + else: + if tfds.Split.VALIDATION not in splits and 'test' not in splits: + raise ValueError('We require a validation or test split in the dataset.') + eval_split = tfds.Split.VALIDATION + if tfds.Split.VALIDATION not in splits: + eval_split = tfds.Split.TEST + train = tfds.load( + name=dataset_name, split=train_split, data_dir=data_dir, + shuffle_files=train_shuffle_files) + valid = tfds.load( + name=dataset_name, split=eval_split, data_dir=data_dir, + shuffle_files=eval_shuffle_files) + keys = None + if info.supervised_keys: + keys = ([info.supervised_keys[0]], [info.supervised_keys[1]]) + return train, valid, info.features, keys + + +def _make_info(shape_list, n_classes, dtype): + """Create an info-like tuple for feature given some shapes and vocab size.""" + feature_info = collections.namedtuple( + 'FeatureInfo', ['shape', 'n_classes', 'dtype']) + cur_shape = list(shape_list[0]) + # We need to merge the provided shapes, put None where they disagree. + for shape in shape_list: + if len(shape) != len(cur_shape): + raise ValueError('Shapes need to have the same number of dimensions.') + for i in range(len(shape)): + if cur_shape[i] is not None: + if shape[i] != cur_shape[i]: + cur_shape[i] = None + return feature_info(cur_shape, n_classes, dtype) + + +def _select_features(example, feature_list=None): + """Select a subset of features from the example dict.""" + feature_list = feature_list or ['inputs', 'targets'] + return {f: example[f] for f in feature_list if f in example} + + +def _eager_dataset_iterator(dataset): + for item in dataset: + flat = tf.nest.flatten(item) + flat = [el.numpy() for el in flat] + yield tf.nest.pack_sequence_as(item, flat) + + +def _train_and_eval_dataset_v1(problem_name, data_dir): + """Return train and evaluation datasets, feature info and supervised keys.""" + with tf.device('cpu:0'): + problem = t2t_problems.problem(problem_name) + train_dataset = problem.dataset(tf.estimator.ModeKeys.TRAIN, data_dir) + train_dataset = train_dataset.map(_select_features) + eval_dataset = problem.dataset(tf.estimator.ModeKeys.EVAL, data_dir) + eval_dataset = eval_dataset.map(_select_features) + hparams = problem.get_hparams() + # We take a few training examples to guess the shapes. + input_shapes, target_shapes, examples = [], [], [] + if tf.executing_eagerly(): + for example in _eager_dataset_iterator(train_dataset.take(3)): + examples.append(example) + else: + example_tensor = train_dataset.make_one_shot_iterator().get_next() + sess = tf.Session() + example1 = sess.run(example_tensor) + example2 = sess.run(example_tensor) + example3 = sess.run(example_tensor) + examples = [example1, example2, example3] + # We use 'inputs' as input except for purely auto-regressive tasks like + # language models where 'targets' are used as input_key. + input_key = 'inputs' if 'inputs' in examples[0] else 'targets' + supervised_keys = ([input_key], ['targets']) + for example in examples: + input_shapes.append(list(example[input_key].shape)) + target_shapes.append(list(example['targets'].shape)) + input_vocab_size = hparams.vocab_size[input_key] + target_vocab_size = hparams.vocab_size['targets'] + input_dtype = examples[0][input_key].dtype + target_dtype = examples[0]['targets'].dtype + input_info = _make_info(input_shapes, input_vocab_size, input_dtype) + target_info = _make_info(target_shapes, target_vocab_size, target_dtype) + info = {input_key: input_info, 'targets': target_info} + return train_dataset, eval_dataset, info, supervised_keys + + +@gin.configurable(blacklist=['dataset', 'training', 'shapes', + 'target_names', 'n_devices']) +def batch_fn(dataset, training, shapes, target_names, n_devices, + batch_size_per_device=32, batch_size=None, eval_batch_size=32, + bucket_length=32, buckets=None, + buckets_include_inputs_in_length=False, + batch_shuffle_size=128, max_eval_length=None): + """Batching function.""" + del target_names + # Batch size is batch_size_per_device * n_devices unless given directly. + batch_size = batch_size or batch_size_per_device * n_devices + # If bucketing is not specified, check if target shapes are variable. + cur_batch_size = batch_size if training else eval_batch_size + # Make cur_batch_size divisible by n_devices. + cur_batch_size = max(cur_batch_size // n_devices, 1) * n_devices + # Create heuristic buckets is none are specified. + if buckets is None: + variable_target_shapes = False + target_shape = shapes[1] + for dim in target_shape: + if dim is None: + variable_target_shapes = True + tf.logging.info('Heuristically setting bucketing to %s based on shapes ' + 'of target tensors.' % variable_target_shapes) + if variable_target_shapes: + bucket_boundaries = [bucket_length // 4, bucket_length // 2, + bucket_length, bucket_length * 2, + bucket_length * 4, bucket_length * 8, + bucket_length * 16] + if not training: + max_eval_length = max_eval_length or bucket_length * 32 + bucket_boundaries[-1] = max_eval_length + # We will pad to boundaries which pads to bucket_boundary - 1: add 1 here. + bucket_boundaries = [b + 1 for b in bucket_boundaries] + bucket_batch_sizes = [cur_batch_size * 4, cur_batch_size * 2, + cur_batch_size, cur_batch_size // 2, + cur_batch_size // 4, cur_batch_size // 8, + cur_batch_size // 16, 1] + if not training: + bucket_batch_sizes[-2] = cur_batch_size // max_eval_length + # Make batch sizes divisible by n_devices. + bucket_batch_sizes = [max(b // n_devices, 1) * n_devices + for b in bucket_batch_sizes] + buckets = (bucket_boundaries, bucket_batch_sizes) + + if buckets: + tf.logging.info('Bucketing with buckets %s.' % str(buckets)) + def example_length(example_inputs, target): + """The length function used by bucket_by_sequence_length to bucket.""" + other_length = 0 + if buckets_include_inputs_in_length: + other_length = tf.shape(example_inputs['inputs'])[0] + return tf.maximum(tf.shape(target)[0], other_length) + boundaries, batch_sizes = buckets + dataset = dataset.apply(tf.data.experimental.bucket_by_sequence_length( + example_length, boundaries, batch_sizes, + pad_to_bucket_boundary=True)) + else: + dataset = dataset.padded_batch(cur_batch_size, shapes) + if training: + return dataset.shuffle(batch_shuffle_size) + return dataset + + +@gin.configurable(blacklist=['dataset', 'training']) +def cifar10_no_augmentation_preprocess(dataset, training): + del training + + def cast_image(features, targets): + features['image'] = tf.cast(features['image'], tf.float32) / 255.0 + return features, targets + + dataset = dataset.map(cast_image) + return dataset + + +@gin.configurable(blacklist=['dataset', 'training']) +def cifar10_augmentation_preprocess(dataset, training): + """Preprocessing for cifar10 with augmentation (see below).""" + + def augment_image(image): + """Image augmentation suitable for CIFAR-10/100. + + As described in https://arxiv.org/pdf/1608.06993v3.pdf (page 5). + + Args: + image: a Tensor. + Returns: + Tensor of the same shape as image. + """ + image = tf.image.resize_image_with_crop_or_pad(image, 40, 40) + image = tf.random_crop(image, [32, 32, 3]) + image = tf.image.random_flip_left_right(image) + return image + + def augment(features, targets): + features['image'] = augment_image(features['image']) + return features, targets + + def cast_image(features, targets): + features['image'] = tf.cast(features['image'], tf.float32) / 255.0 + return features, targets + + if training: + dataset = dataset.map(augment) + dataset = dataset.map(cast_image) + return dataset + + +def no_preprocess(dataset, training): + del training + return dataset + + +@gin.configurable(blacklist=['dataset', 'training']) +def concat_preprocess(dataset, training, pad_symbol=0): + """Pre-processing function that concatenates input and target for LM.""" + del training + + def concat(features, targets): + inp = features['inputs'] + pad = tf.expand_dims(tf.zeros_like(inp[0]) + pad_symbol, axis=0) + concat = tf.concat([pad, inp, pad, targets], axis=0) + # Note: we're updating existing features dictionary here, so make sure + # it is not re-used in some other ways outside of this function. + features['inputs'] = concat + return features, concat + + dataset = dataset.map(concat) + return dataset + + +@gin.configurable(blacklist=['dataset', 'training']) +def lm1b_preprocess(dataset, training, + max_target_length=-1, max_eval_target_length=-1): + """Preprocessing for LM1B: filter out targets exceeding maximum length.""" + + def target_right_length(_, target): + return tf.less(tf.shape(target)[0], max_target_length + 1) + + def eval_target_right_length(_, target): + return tf.less(tf.shape(target)[0], max_eval_target_length + 1) + + if max_target_length > 0 and training: + dataset = dataset.filter(target_right_length) + + if max_eval_target_length > 0 and not training: + dataset = dataset.filter(eval_target_right_length) + + return dataset + + +# TODO(lukaszkaiser): find a single more abstract way of text pre-processing. +@gin.configurable(blacklist=['dataset', 'training']) +def wmt_preprocess(dataset, training, max_length=-1, max_eval_length=-1): + """Preprocessing for LM1B: filter out targets exceeding maximum length.""" + + def train_right_length(example, target): + l = tf.maximum(tf.shape(example['inputs'])[0], tf.shape(target)[0]) + return tf.less(l, max_length + 1) + + def eval_right_length(example, target): + l = tf.maximum(tf.shape(example['inputs'])[0], tf.shape(target)[0]) + return tf.less(l, max_eval_length + 1) + + if max_length > 0 and training: + dataset = dataset.filter(train_right_length) + + if max_eval_length > 0 and not training: + dataset = dataset.filter(eval_right_length) + + return dataset + + +@gin.configurable(blacklist=['dataset', 'training']) +def wmt_concat_preprocess(dataset, training, max_length=-1, max_eval_length=-1): + """Preprocessing for WMT: filter exceeding maximum length and concatenate.""" + dataset = wmt_preprocess(dataset, training, max_length, max_eval_length) + + def concat_and_add_mask(features, targets): + inp = features['inputs'] + pad = tf.expand_dims(tf.zeros_like(inp[0]), axis=0) + concat = tf.concat([inp, pad, targets], axis=0) + mask = tf.concat([tf.zeros_like(inp), pad, tf.ones_like(targets)], axis=0) + features['inputs'] = concat + features['mask'] = mask + return features, concat + + dataset = dataset.map(concat_and_add_mask) + return dataset + + +@gin.configurable(whitelist=['preprocess_fun', 'shuffle_buffer_size']) +def shuffle_and_batch_data(dataset, + target_names, + features_info, + training, + n_devices, + shuffle_buffer_size=1024, + preprocess_fun=no_preprocess): + """Shuffle and batch the given dataset.""" + def append_targets(example): + """Append targets to the example dictionary. Needed for Keras.""" + if len(target_names) == 1: + return (example, example[target_names[0]]) + targets = {} + for name in target_names: + targets[name] = example[name] + return (example, targets) + dataset = dataset.map(append_targets) + # TODO(pkozakowski): Repeat both the training and evaluation set, so we don't + # have incomplete batches during evaluation. This will be a problem when we + # add an option to evaluate on the whole dataset, then we'll need to think of + # a different solution. + dataset = dataset.repeat() + if training: + # Skip a random fraction at the beginning of the stream. The skip is + # essential for synchronous highly-parallel training to avoid multiple + # replicas reading the same data in lock-step. + dataset = dataset.skip(random.randint(0, _MAX_SKIP_EXAMPLES)) + dataset = preprocess_fun(dataset, training) + shapes = {k: features_info[k].shape for k in features_info} + shapes = (shapes, shapes[target_names[0]]) + dataset = dataset.shuffle(shuffle_buffer_size) + dataset = batch_fn(dataset, training, shapes, target_names, n_devices) + return dataset.prefetch(2) + + +def _train_and_eval_batches(dataset, data_dir, input_name, n_devices): + """Return train and eval batches with input name and shape.""" + (train_data, eval_data, features_info, keys) = train_and_eval_dataset( + dataset, data_dir) + input_names, target_names = keys[0], keys[1] + train_batches = shuffle_and_batch_data( + train_data, target_names, features_info, training=True, + n_devices=n_devices) + train_eval_batches = shuffle_and_batch_data( # Data for eval-on-train. + train_data, target_names, features_info, training=False, + n_devices=n_devices) + eval_batches = shuffle_and_batch_data( + eval_data, target_names, features_info, training=False, + n_devices=n_devices) + input_name = input_name or input_names[0] + input_shape = features_info[input_name].shape + input_dtype = features_info[input_name].dtype + target_shape = features_info[target_names[0]].shape + target_dtype = features_info[target_names[0]].dtype + return (train_batches, train_eval_batches, eval_batches, + input_name, list(input_shape), input_dtype, + list(target_shape), target_dtype) diff --git a/trax/inputs_test.py b/trax/inputs_test.py new file mode 100644 index 000000000..67052b151 --- /dev/null +++ b/trax/inputs_test.py @@ -0,0 +1,72 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.inputs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gin +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds +from trax import inputs + + +def test_dataset_ints(lengths): + """Create a test dataset of int64 tensors of shape [length].""" + def generator(): + """Sample generator of sequences of shape [length] of type int64.""" + for length in lengths: + x = np.zeros([length], dtype=np.int64) + yield (x, x) # Inputs and targets are the same here. + types = (tf.int64, tf.int64) + shapes = (tf.TensorShape([None]), tf.TensorShape([None])) + return tf.data.Dataset.from_generator( + generator, output_types=types, output_shapes=shapes) + + +class InputsTest(tf.test.TestCase): + + def setUp(self): + gin.clear_config() + + def test_batch_fn(self): + dataset = test_dataset_ints([32]) + dataset = dataset.repeat(10) + batches = inputs.batch_fn( + dataset, True, ([None], [None]), [], 1, batch_size=10) + count = 0 + for example in tfds.as_numpy(batches): + count += 1 + self.assertEqual(example[0].shape[0], 10) # Batch size = 10. + self.assertEqual(count, 1) # Just one batch here. + + def test_batch_fn_n_devices(self): + dataset = test_dataset_ints([32]) + dataset = dataset.repeat(9) + batches = inputs.batch_fn( + dataset, True, ([None], [None]), [], 9, batch_size=10) + count = 0 + for example in tfds.as_numpy(batches): + count += 1 + # Batch size adjusted to be divisible by n_devices. + self.assertEqual(example[0].shape[0], 9) + self.assertEqual(count, 1) # Just one batch here. + + +if __name__ == "__main__": + tf.test.main() diff --git a/trax/jaxboard.py b/trax/jaxboard.py new file mode 100644 index 000000000..43555cb0a --- /dev/null +++ b/trax/jaxboard.py @@ -0,0 +1,350 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Write Summaries from JAX for use with Tensorboard. + +See jaxboard_demo.py for example usage. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import io +import struct +import time +import warnings +import wave +import matplotlib as mpl +# Necessary to prevent attempted Tk import: +with warnings.catch_warnings(): + warnings.simplefilter('ignore') + mpl.use('Agg') +# pylint: disable=g-import-not-at-top +import matplotlib.pyplot as plt +import numpy as onp +import tensorflow as tf +from tensorflow import HistogramProto +from tensorflow import Summary +from tensorflow import SummaryMetadata +from tensorflow.io import gfile + +# pylint: disable=g-direct-tensorflow-import +from tensorflow.core.util import event_pb2 +from tensorflow.python.summary.writer.event_file_writer import EventFileWriter +# pylint: enable=g-direct-tensorflow-import + + +def _pack_images(images, rows, cols): + """Helper utility to make a tiled field of images from numpy arrays. + + Args: + images: Image tensor in shape [N, W, H, C]. + rows: Number of images per row in tiled image. + cols: Number of images per column in tiled image. + + Returns: + A tiled image of shape [W * rows, H * cols, C]. + Truncates incomplete rows. + """ + shape = onp.shape(images) + width, height, depth = shape[-3:] + images = onp.reshape(images, (-1, width, height, depth)) + batch = onp.shape(images)[0] + rows = onp.minimum(rows, batch) + cols = onp.minimum(batch // rows, cols) + images = images[:rows * cols] + images = onp.reshape(images, (rows, cols, width, height, depth)) + images = onp.transpose(images, [0, 2, 1, 3, 4]) + images = onp.reshape(images, [rows * width, cols * height, depth]) + return images + + +class SummaryWriter(object): + """Saves data in event and summary protos for tensorboard.""" + + def __init__(self, log_dir): + """Create a new SummaryWriter. + + Args: + log_dir: path to record tfevents files in. + """ + # If needed, create log_dir directory as well as missing parent directories. + if not gfile.isdir(log_dir): + gfile.makedirs(log_dir) + + self._event_writer = EventFileWriter(log_dir, 10, 120, None) + self._step = 0 + self._closed = False + + def add_summary(self, summary, step): + event = event_pb2.Event(summary=summary) + event.wall_time = time.time() + if step is not None: + event.step = int(step) + self._event_writer.add_event(event) + + def close(self): + """Close SummaryWriter. Final!""" + if not self._closed: + self._event_writer.close() + self._closed = True + del self._event_writer + + def __del__(self): # safe? + self.close() + + def flush(self): + self._event_writer.flush() + + def scalar(self, tag, value, step=None): + """Saves scalar value. + + Args: + tag: str: label for this data + value: int/float: number to log + step: int: training step + """ + value = float(onp.array(value)) + if step is None: + step = self._step + else: + self._step = step + summary = Summary(value=[Summary.Value(tag=tag, simple_value=value)]) + self.add_summary(summary, step) + + def image(self, tag, image, step=None): + """Saves RGB image summary from onp.ndarray [H,W], [H,W,1], or [H,W,3]. + + Args: + tag: str: label for this data + image: ndarray: [H,W], [H,W,1], [H,W,3] save image in greyscale or colors/ + step: int: training step + """ + image = onp.array(image) + if step is None: + step = self._step + else: + self._step = step + if len(onp.shape(image)) == 2: + image = image[:, :, onp.newaxis] + if onp.shape(image)[-1] == 1: + image = onp.repeat(image, 3, axis=-1) + image_strio = io.BytesIO() + plt.imsave(image_strio, image, format='png') + image_summary = Summary.Image( + encoded_image_string=image_strio.getvalue(), + colorspace=3, + height=image.shape[0], + width=image.shape[1]) + summary = Summary(value=[Summary.Value(tag=tag, image=image_summary)]) + self.add_summary(summary, step) + + def images(self, tag, images, step=None, rows=None, cols=None): + """Saves (rows, cols) tiled images from onp.ndarray. + + If either rows or cols aren't given, they are determined automatically + from the size of the image batch, if neither are given a long column + of images is produced. This truncates the image batch rather than padding + if it doesn't fill the final row. + + Args: + tag: str: label for this data + images: ndarray: [N,H,W,1] or [N,H,W,3] to tile in 2d + step: int: training step + rows: int: number of rows in tile + cols: int: number of columns in tile + """ + images = onp.array(images) + if step is None: + step = self._step + else: + self._step = step + n_images = onp.shape(images)[0] + if rows is None and cols is None: + rows = 1 + cols = n_images + elif rows is None: + rows = n_images // cols + elif cols is None: + cols = n_images // rows + tiled_images = _pack_images(images, rows, cols) + self.image(tag, tiled_images, step=step) + + def plot(self, tag, mpl_plt, step=None, close_plot=True): + """Saves matplotlib plot output to summary image. + + Args: + tag: str: label for this data + mpl_plt: matplotlib stateful pyplot object with prepared plotting state + step: int: training step + close_plot: bool: automatically closes plot + """ + if step is None: + step = self._step + else: + self._step = step + fig = mpl_plt.get_current_fig_manager() + img_w, img_h = fig.canvas.get_width_height() + image_buf = io.BytesIO() + mpl_plt.savefig(image_buf, format='png') + image_summary = Summary.Image( + encoded_image_string=image_buf.getvalue(), + colorspace=4, # RGBA + height=img_h, + width=img_w) + summary = Summary(value=[Summary.Value(tag=tag, image=image_summary)]) + self.add_summary(summary, step) + if close_plot: + mpl_plt.close() + + def audio(self, tag, audiodata, step=None, sample_rate=44100): + """Saves audio. + + NB: single channel only right now. + + Args: + tag: str: label for this data + audiodata: ndarray [Nsamples,]: data between (-1.0,1.0) to save as wave + step: int: training step + sample_rate: sample rate of passed in audio buffer + """ + audiodata = onp.array(audiodata) + if step is None: + step = self._step + else: + self._step = step + audiodata = onp.clip(onp.squeeze(audiodata), -1, 1) + if audiodata.ndim != 1: + raise ValueError('Audio data must be 1D.') + sample_list = (32767.0 * audiodata).astype(int).tolist() + wio = io.BytesIO() + wav_buf = wave.open(wio, 'wb') + wav_buf.setnchannels(1) + wav_buf.setsampwidth(2) + wav_buf.setframerate(sample_rate) + enc = b''.join([struct.pack(' 0 else onp.concatenate([[0], counts[:end]])) + limits = limits[start:end + 1] + sum_sq = values.dot(values) + histo = HistogramProto( + min=values.min(), + max=values.max(), + num=len(values), + sum=values.sum(), + sum_squares=sum_sq, + bucket_limit=limits.tolist(), + bucket=counts.tolist()) + summary = Summary(value=[Summary.Value(tag=tag, histo=histo)]) + self.add_summary(summary, step) + + def text(self, tag, textdata, step=None): + """Saves a text summary. + + Args: + tag: str: label for this data + textdata: string, or 1D/2D list/numpy array of strings + step: int: training step + Note: markdown formatting is rendered by tensorboard. + """ + if step is None: + step = self._step + else: + self._step = step + smd = SummaryMetadata( + plugin_data=SummaryMetadata.PluginData(plugin_name='text')) + if isinstance(textdata, (str, bytes)): + tensor = tf.make_tensor_proto( + values=[textdata.encode(encoding='utf_8')], shape=(1,)) + else: + textdata = onp.array(textdata) # convert lists, jax arrays, etc. + datashape = onp.shape(textdata) + if len(datashape) == 1: + tensor = tf.make_tensor_proto( + values=[td.encode(encoding='utf_8') for td in textdata], + shape=(datashape[0],)) + elif len(datashape) == 2: + tensor = tf.make_tensor_proto( + values=[ + td.encode(encoding='utf_8') for td in onp.reshape(textdata, -1) + ], + shape=(datashape[0], datashape[1])) + summary = Summary( + value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)]) + self.add_summary(summary, step) + + +# Copied from gin/tf/utils.py:GinConfigSaverHook +def markdownify_operative_config_str(string): + """Convert an operative config string to markdown format.""" + + # TODO(b/37527917): Total hack below. Implement more principled formatting. + def process(line): + """Convert a single line to markdown format.""" + if not line.startswith('#'): + return ' ' + line + + line = line[2:] + if line.startswith('===='): + return '' + if line.startswith('None'): + return ' # None.' + if line.endswith(':'): + return '#### ' + line + return line + + output_lines = [] + for line in string.splitlines(): + procd_line = process(line) + if procd_line is not None: + output_lines.append(procd_line) + + return '\n'.join(output_lines) diff --git a/trax/layers/README.md b/trax/layers/README.md new file mode 100644 index 000000000..4782fa2dc --- /dev/null +++ b/trax/layers/README.md @@ -0,0 +1,60 @@ +# Trax Layers + + + +## Base layer structure + +All layers inherit from the Layer class and generally need to implement 2 +methods: + +```python +def forward(self, inputs, params=(), state=(), **kwargs): + """Computes the layer's output as part of a forward pass through the model.""" + +def new_params_and_state(self, input_shape, input_dtype, rng): + """Returns a (params, state) pair suitable for initializing this layer.""" +``` + +The base Layer class wraps these functions and provides initialization +and call functions to be used as follows. + +```python +layer = MyLayer() +x = np.zeros(10) +rng = random.get_prng(0) +layer.initialize_once(x.shape, x.dtype, rng) +output = layer(x) +``` + +## Decorator + +To create simple layers, especially ones without parameters, use the layer +decorator. + +```python +@base.layer() +def Relu(x, **unused_kwargs): + return np.maximum(x, 0.) +``` + +## Parameter sharing + +Parameters are shared when the same layer object is used. + +```python +standard_mlp = layers.Serial(layers.Dense(10), layers.Dense(10)) +layer = Dense(10) +shared_parameters_mlp = layers.Serial(layer, layer) +``` +For this reason, if you call `layer.initialize_once(...)` for the second time +on an already initialized layer, it will not re-initialize the layer. + +## Core layers + +* Dense +* Conv + +## Layer composition + +* Serial +* Parallel diff --git a/trax/layers/__init__.py b/trax/layers/__init__.py new file mode 100644 index 000000000..898e96286 --- /dev/null +++ b/trax/layers/__init__.py @@ -0,0 +1,62 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Layers defined in trax.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gin +# We create a flat layers.* namespace for uniform calling conventions as we +# upstream changes. +# pylint: disable=wildcard-import +from trax.layers.attention import * +from trax.layers.base import * +from trax.layers.combinators import * +from trax.layers.convolution import * +from trax.layers.core import * +from trax.layers.initializers import * +from trax.layers.metrics import * +from trax.layers.normalization import * +from trax.layers.pooling import * +from trax.layers.reversible import * +from trax.layers.rnn import * + + +# Ginify +def layer_configure(*args, **kwargs): + kwargs["module"] = "trax.layers" + return gin.external_configurable(*args, **kwargs) + +# pylint: disable=used-before-assignment +# pylint: disable=invalid-name +Relu = layer_configure(Relu) +Sigmoid = layer_configure(Sigmoid) +Tanh = layer_configure(Tanh) +HardSigmoid = layer_configure(HardSigmoid) +HardTanh = layer_configure(HardTanh) +Exp = layer_configure(Exp) +LogSoftmax = layer_configure(LogSoftmax) +Softmax = layer_configure(Softmax) +Softplus = layer_configure(Softplus) + +DotProductCausalAttention = layer_configure( + DotProductCausalAttention, blacklist=["mode"]) +MemoryEfficientCausalAttention = layer_configure( + MemoryEfficientCausalAttention, blacklist=["mode"]) +TimeBinCausalAttention = layer_configure( + TimeBinCausalAttention, blacklist=["mode"]) +LSHCausalAttention = layer_configure( + LSHCausalAttention, blacklist=["mode"]) diff --git a/trax/layers/attention.py b/trax/layers/attention.py new file mode 100644 index 000000000..3d1f5c488 --- /dev/null +++ b/trax/layers/attention.py @@ -0,0 +1,1355 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Attention Layers.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import random + +import jax +import numpy as onp + +from trax import backend +from trax.backend import numpy as np +from trax.layers import base +from trax.layers import combinators as cb +from trax.layers import core +from trax.layers import initializers as init + + +# Layers are always CamelCase, but functions in general are snake_case +# pylint: disable=invalid-name + + +@base.layer() +def ShiftRight(x, mode='train', **unused_kwargs): + """Layer to shift the tensor to the right by padding on axis 1.""" + if mode == 'predict': + # Do nothing in predict mode, as then the sequence length is 1. + return x + + pad_widths = [(0, 0)] * len(x.shape) + pad_widths[1] = (1, 0) # Padding on axis=1 + padded = np.pad(x, pad_widths, mode='constant', + constant_values=x.dtype.type(0)) + return padded[:, :-1] + + +@base.layer() +def CausalMask(x, params, axis=-1, **kwargs): + del params, kwargs + size = x.shape[axis] + return onp.tril(onp.ones((1, size, size), dtype=onp.bool_), k=0) + + +@base.layer() +def PaddingMask(x, params, pad=0, **kwargs): + del params, kwargs + return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1])) + + +@base.layer(n_inputs=2) +def EncoderDecoderMask(x, **unused_kwargs): + """Makes encoder-decoder mask from decoder input and a padding mask.""" + decoder_input, padding_mask = x + padding_mask = np.reshape( + padding_mask, (padding_mask.shape[0], 1, 1, padding_mask.shape[-1])) + # Final mask shape is [batch, 1 for heads, decoder-len, encoder-len]. + return padding_mask + np.zeros((1, 1, decoder_input.shape[1], 1)) + + +class PositionalEncoding(base.Layer): + """Implements bare positional encoding.""" + + def __init__(self, max_len=2048, mode='train'): + super(PositionalEncoding, self).__init__() + self._max_len = max_len + self._mode = mode + + def forward(self, inputs, params=(), state=(), **kwargs): + if self._mode in ('train', 'eval'): + x = inputs + symbol_size = np.shape(x)[1] + return (x + params[:, :symbol_size, :], state) + else: + assert self._mode == 'predict' + # Fast inference: return consectutive elements of the encoding sequence, + # storing the index in state. + return (inputs + np.expand_dims(params[:, state, :], 1), state + 1) + + def new_params_and_state(self, input_shape, input_dtype, rng): + del input_dtype, rng + d_feature = input_shape[-1] + pe = onp.zeros((self._max_len, d_feature), dtype=onp.float32) + position = onp.arange(0, self._max_len)[:, onp.newaxis] + div_term = onp.exp( + onp.arange(0, d_feature, 2) * -(onp.log(10000.0) / d_feature)) + pe[:, 0::2] = onp.sin(position * div_term) + pe[:, 1::2] = onp.cos(position * div_term) + pe = pe[onp.newaxis, :, :] # [1, self._max_len, d_feature] + params = np.array(pe) # These are trainable parameters, initialized above. + state = 0 if self._mode == 'predict' else () + return params, state + + +def DotProductAttention(query, key, value, mask, dropout, mode, rng): + """Core dot product self-attention. + + Args: + query: array of representations + key: array of representations + value: array of representations + mask: attention-mask, gates attention + dropout: float: dropout rate + mode: 'eval' or 'train': whether to use dropout + rng: JAX PRNGKey: subkey for disposable use + + Returns: + Self attention for q, k, v arrays. + """ + depth = np.shape(query)[-1] + dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth) + if mask is not None: + # TODO(kitaev): workaround for https://github.com/google/jax/issues/850 + # We must ensure that both mask and the -1e9 constant have a data dependency + # on the input. Broadcasted copies of these use a lot of memory, so they + # should be computed at runtime (rather than being global constants). + if backend.get_name() == 'jax': + mask = jax.lax.tie_in(dots, mask) + dots = np.where(mask, dots, np.full_like(dots, -1e9)) + # Softmax. + dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True)) + if dropout >= 1.0: + raise ValueError('Dropout rates must be lower than 1.') + if dropout is not None and dropout > 0.0 and mode == 'train': + keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape) + dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots)) + out = np.matmul(dots, value) + return out + + +@base.layer(n_inputs=4, n_outputs=2) +def PureAttention(x, params, n_heads=1, dropout=0.0, mode='train', **kwargs): + """Pure transformer-style multi-headed attention. + + Args: + x: inputs (q, k, v, mask) + params: parameters (none) + n_heads: int: number of attention heads + dropout: float: dropout rate + mode: str: 'train' or 'eval' + **kwargs: other arguments including the rng + + Returns: + Pure Multi-headed attention result, and the mask. + """ + del params + rng = kwargs.get('rng', None) + q, k, v, mask = x + d_feature = q.shape[-1] + assert d_feature % n_heads == 0 + d_head = d_feature // n_heads + nbatch = np.shape(q)[0] + # nbatch, seqlen, d_feature --> nbatch, n_heads, seqlen, d_head + def SplitHeads(x): + return np.transpose( + np.reshape(x, (nbatch, -1, n_heads, d_head)), (0, 2, 1, 3)) + # nbatch, n_heads, seqlen, d_head --> nbatch, seqlen, d_feature + def JoinHeads(x): # pylint: disable=invalid-name + return np.reshape( + np.transpose(x, (0, 2, 1, 3)), (nbatch, -1, n_heads * d_head)) + # Split heads, dot-product attention, rejoin heads. + res = JoinHeads( + DotProductAttention( + SplitHeads(q), SplitHeads(k), SplitHeads(v), mask, + dropout=dropout, mode=mode, rng=rng)) + return res, mask # Keep the mask. + + +def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'): + """Transformer-style multi-headed attention. + + Accepts inputs of the form q, k, v, mask. + + Args: + d_feature: int: dimensionality of feature embedding + n_heads: int: number of attention heads + dropout: float: dropout rate + mode: str: 'train' or 'eval' + + Returns: + Multi-headed self-attention result and the mask. + """ + return [ + cb.Parallel( + core.Dense(d_feature), + core.Dense(d_feature), + core.Dense(d_feature), + ), + PureAttention( # pylint: disable=no-value-for-parameter + n_heads=n_heads, dropout=dropout, mode=mode), + core.Dense(d_feature), + ] + + +def Attention(d_feature, n_heads=1, dropout=0.0, mode='train'): + """Transformer-style multi-headed attention. + + Accepts inputs of the form (x, mask) and constructs (q, k, v) from x. + + Args: + d_feature: int: dimensionality of feature embedding + n_heads: int: number of attention heads + dropout: float: dropout rate + mode: str: 'train' or 'eval' + + Returns: + Multi-headed self-attention result and the mask. + """ + return [ + cb.Dup(), cb.Dup(), + AttentionQKV(d_feature, n_heads=n_heads, dropout=dropout, mode=mode), + ] + + +def BasicCausalAttention(d_feature, n_heads=1, dropout=0.0, mode='train'): + """Transformer-style multi-headed causal attention. + + This implementation is less configurable than the CausalAttention layer + defined below, but it shares code with the non-causal attention. + + # TODO(jonni,lukaszkaiser): standardize and improve layer comments. + Accepts inputs of the form x and constructs (q, k, v) and causal mask from x. + + Args: + d_feature: int: dimensionality of feature embedding + n_heads: int: number of attention heads + dropout: float: dropout rate + mode: str: 'train' or 'eval' + + Returns: + Multi-headed self-attention result. + """ + return [ + cb.Dup(), + cb.Parallel([], CausalMask(axis=-2)), # pylint: disable=no-value-for-parameter + Attention(d_feature, n_heads=n_heads, dropout=dropout, mode=mode), + cb.Parallel([], cb.Drop()), # x + ] + + +class ShiftRightLearned(base.Layer): + """Layer constructor function for shifting right by a learned vector.""" + + def __init__(self, initializer=init.RandomNormalInitializer(0.01)): + super(ShiftRightLearned, self).__init__() + self._initializer = initializer + + def forward(self, x, params=(), state=(), **kwargs): + del kwargs + c = backend.numpy.reshape(params, [1, 1, -1]) + c += backend.numpy.zeros((x.shape[0], 1, x.shape[2]), dtype=x.dtype) + return backend.numpy.concatenate([c, x], axis=1)[:, :-1, :], state + + def new_params_and_state(self, input_shape, input_dtype, rng): + del input_dtype + b = self._initializer((input_shape[-1],), rng) + return b, () + + +class ComputeAttentionHeads(base.Layer): + """Computes queries/keys/values via linear projection. + + The output shape is (n_batch * n_heads, seqlen, d_head); the batch and head + dimensions are fused to allow for more efficient memory layouts. + """ + + def __init__(self, n_heads=1, d_head=64, + kernel_initializer=init.GlorotUniformInitializer()): + super(ComputeAttentionHeads, self).__init__() + self._n_heads = n_heads + self._d_head = d_head + self._kernel_initializer = kernel_initializer + # The lack of a bias term here is consistent with the tensor2tensor + # implementation, and shouldn't have an effect on modeling quality. + # Note that AttentionQKV above is different in that it uses a bias term. + + def forward(self, x, params=(), state=(), **kwargs): + del kwargs + seqlen = x.shape[1] + res = np.dot(x, params) + + # n_batch, seqlen, n_heads*d_head -> n_batch, seqlen, n_heads, d_head + res = np.reshape(res, (x.shape[0], seqlen, self._n_heads, self._d_head)) + # n_batch, seqlen, n_heads, d_head -> n_batch, n_heads, seqlen, d_head + res = np.transpose(res, (0, 2, 1, 3)) + # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head + res = np.reshape(res, (-1, seqlen, self._d_head)) + + return res, state + + def new_params_and_state(self, input_shape, input_dtype, rng): + del input_dtype + w = self._kernel_initializer( + (input_shape[-1], self._n_heads * self._d_head), rng) + return w, () + + +class ComputeAttentionOutput(base.Layer): + """Joins outputs from different heads via linear projection.""" + + def __init__(self, n_heads=1, d_model=1024, + kernel_initializer=init.GlorotUniformInitializer()): + super(ComputeAttentionOutput, self).__init__() + self._n_heads = n_heads + self._d_model = d_model + self._kernel_initializer = kernel_initializer + # The lack of a bias term here is consistent with the tensor2tensor + # implementation, and shouldn't have an effect on modeling quality. + # Note that AttentionQKV above is different in that it uses a bias term. + + def forward(self, x, params=(), state=(), **kwargs): + del kwargs + seqlen = x.shape[1] + d_head = x.shape[2] + + x = np.reshape(x, (-1, self._n_heads, seqlen, d_head)) + x = np.transpose(x, (0, 2, 1, 3)) # -> n_batch, seqlen, n_heads, d_head + x = np.reshape(x, (-1, seqlen, self._n_heads * d_head)) + + return np.dot(x, params), state + + def new_params_and_state(self, input_shape, input_dtype, rng): + del input_dtype + w = self._kernel_initializer( + (input_shape[-1] * self._n_heads, self._d_model), rng) + return w, () + + +class BaseCausalAttention(base.Layer): + """Base class for variants of causal self-attention.""" + + def __init__(self, mode='train'): + del mode + super(BaseCausalAttention, self).__init__(n_inputs=3) + + def forward(self, inputs, params=(), state=(), rng=None, **kwargs): + """Forward pass for the attention layer.""" + raise NotImplementedError() + + def forward_and_backward(self, inputs, grad, **kwargs): + """Performs both forward and backward pass for the attention layer. + + This is used in reversible models: for the backward pass of a reversible + model, we need to compute both the forward direction (to recover the + previous layer's activations) and the backward direction simultaneously. + Some computation can be shared between the forward and backward directions, + which makes it more efficient to implement them jointly. + + This method assumes that the layer is stateless and has no parameters. + + Args: + inputs: A tuple (q, k, v), where each element has shape + n_batch*n_heads, seqlen, d_head + grad: gradient signal for the layer output. + **kwargs: kwargs for the layer + + Returns: + A nested-tuple structure (output, (q_grad, k_grad, v_grad)) that contains + the output of the forward pass and the gradient signal for each input. + """ + raise NotImplementedError() + + +def _fast_inference_init_state(input_shapes, input_dtypes, buffer_length): + """Initializes state of a causal attention layer for fast inference.""" + ((batch_size, _, _), _, _) = input_shapes + def init_buffer(shape, dtype): + (_, _, depth) = shape + return np.zeros((batch_size, buffer_length, depth), dtype=dtype) + (_, k, v) = tuple( + init_buffer(shape, dtype) + for (shape, dtype) in zip(input_shapes, input_dtypes) + ) + mask = np.zeros((batch_size, 1, buffer_length)) + index = 0 + state = (k, v, mask, index) + return state + + +def _fast_inference_update_state(inputs, state): + """Updates state of a causal attention layer for fast inference.""" + assert backend.get_name() == 'jax', ( + 'JAX backend is required to use the predict mode.') + for x in inputs: + assert x.shape[1] == 1, ( + 'In predict mode the input sequence must be of length 1.') + # Fast inference: run with only 1 query in each step, storing the sequence + # of keys and values calculated so far in state. + (_, new_k, new_v) = inputs + (ks, vs, mask, index) = state + ks = jax.ops.index_update(ks, jax.ops.index[:, index, :], new_k[:, 0, :]) + vs = jax.ops.index_update(vs, jax.ops.index[:, index, :], new_v[:, 0, :]) + mask = jax.ops.index_update(mask, jax.ops.index[:, :, index], 1) + return (ks, vs, mask, index + 1) + + +class DotProductCausalAttention(BaseCausalAttention): + """A standard (non-memory-efficient) dot product attention implementation.""" + + def __init__(self, dropout=0.0, mode='train'): + super(DotProductCausalAttention, self).__init__() + self._dropout = dropout + self._mode = mode + + def forward(self, inputs, params=(), state=(), rng=None, **kwargs): + del params + q, k, v = inputs + if self._mode in ('train', 'eval'): + mask_size = q.shape[-2] + # Not all backends define np.tril. However, using onp.tril is inefficient + # in that it creates a large global constant. TODO(kitaev): try to find an + # alternative that works across all backends. + if backend.get_name() == 'jax': + mask = np.tril( + np.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0) + else: + mask = onp.tril( + onp.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0) + else: + assert self._mode == 'predict' + state = _fast_inference_update_state(inputs, state) + (k, v, mask, _) = state + + res = DotProductAttention( + q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=rng) + return res, state + + def forward_and_backward(self, inputs, ct, **kwargs): + assert backend.get_name() == 'jax', ( + 'JAX backend is required to use forward_and_backward.') + # Simultaneous forward pass and backprop through the attention mechanism. + def _do_forward(x): # pylint: disable=invalid-name + res, _ = self.forward(x, **kwargs) + return res + output, vjpfun = jax.vjp(_do_forward, inputs) + return output, vjpfun(ct)[0] + + def new_params_and_state(self, input_shapes, input_dtype, rng): + if self._mode in ('train', 'eval'): + return (), () + + assert self._mode == 'predict' + params = () + # Buffer length is hardcoded for now. TODO(pkozakowski): Pass it from the + # model. + max_len = 2048 + state = _fast_inference_init_state(input_shapes, input_dtype, max_len) + return params, state + + +class MemoryEfficientCausalAttention(BaseCausalAttention): + """Memory-efficient dot product attention. + + This layer performs causal attention on long sequences without running out + of memory. Instead of computing dot products for all query-key pairs at once, + it uses a loop to compute attention for a small set of query positions at a + time. The "loop_stride" parameter controls how many query positions are + considered at each iteration of the loop. + + Note that this class does not slice along the batch/head dimension. Looping + over batch elements and heads instead of query positions is also a viable + option. We haven't implemented it, but it may perform well, too. + """ + + def __init__(self, loop_stride, dropout, mode, share_qk=False, hard_k=0): + assert backend.get_name() == 'jax', ( + 'JAX backend is required to use MemoryEfficientCausalAttention.') + super(MemoryEfficientCausalAttention, self).__init__() + self._loop_stride = loop_stride + if dropout >= 1.0: + raise ValueError('Dropout rates must be lower than 1.') + if mode == 'train': + self.dropout = dropout + else: + self.dropout = None + self._share_qk = share_qk + self._hard_k = hard_k + + def forward(self, inputs, params=(), state=(), **kwargs): + del params + output, _ = self.forward_and_backward(inputs, None, **kwargs) + return output, state + + def has_backward(self): + return True + + def backward(self, inputs, output, ct, params=(), state=(), **kwargs): + del output, params, state + _, inputs_ct = self.forward_and_backward(inputs, ct, **kwargs) + return inputs_ct, () + + def make_unit_length(self, x, epsilon=1e-6): + variance = np.mean(x**2, axis=-1, keepdims=True) + norm_inputs = x / np.sqrt(variance + epsilon) + return norm_inputs + + def forward_and_backward(self, inputs, ct, rng=None, **kwargs): + del kwargs + query, key, value = inputs + depth = np.shape(query)[-1] + do_backprop = ct is not None + # jax uses the term cotangent (ct) to refer to gradient signals, and + # vector-Jacobian product (vjp) for back-propagation through a layer. + + def make_mask(N, M, k): # pylint: disable=invalid-name + """Constructs a slice of the causal attention mask. + + Args: + N: number of query positions + M: number of key positions + k: position of the initial query element + + Returns: + N x M mask, where 1.0 indicates that attention is not allowed. + """ + x = jax.lax.tie_in(k, np.arange(N, dtype=np.int32)) + y = jax.lax.tie_in(k, np.arange(M, dtype=np.int32)) + mask = jax.lax.lt( + (jax.lax.broadcast_in_dim( + x, shape=(N, M), broadcast_dimensions=(0,)) + k), + jax.lax.broadcast(y, [N])) + mask = jax.lax.convert_element_type(mask, np.float32) + return mask + + def make_self_mask(N, M, k): # pylint: disable=invalid-name + """Masks out elements attending to self. + + Args: + N: number of query positions + M: number of key positions + k: position of the initial query element + + Returns: + N x M mask, where 1.0 indicates that attention is not allowed. + """ + x = jax.lax.tie_in(k, np.arange(N, dtype=np.int32)) + y = jax.lax.tie_in(k, np.arange(M, dtype=np.int32)) + mask = jax.lax.eq( + (jax.lax.broadcast_in_dim( + x, shape=(N, M), broadcast_dimensions=(0,)) + k), + jax.lax.broadcast(y, [N])) + mask = jax.lax.convert_element_type(mask, np.float32) + return mask + + def forward_slice(query_slice, q_loop_idx, key, value): # pylint: disable=invalid-name + """Forward pass for a subset of the query vectors.""" + if self._share_qk: + key = self.make_unit_length(key) + + dots = np.matmul( + query_slice, np.swapaxes(key, -1, -2)) / np.sqrt(depth) + + # Causal masking + mask = make_mask(dots.shape[-2], dots.shape[-1], q_loop_idx) + dots = dots - 1e9 * mask + + # Mask out attention to self except when no other targets are available. + if self._share_qk: + self_mask = make_self_mask(dots.shape[-2], dots.shape[-1], q_loop_idx) + dots = dots - 1e5 * self_mask + + # Softmax. + dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True)) + + if self.dropout is not None and self.dropout > 0.0: + # Dropout is broadcast across the batch+head dimension + dropout_shape = (1, dots.shape[-2], dots.shape[-1]) + slice_rng = jax.random.fold_in(rng, q_loop_idx) + keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout) + keep = backend.random.bernoulli(slice_rng, keep_prob, dropout_shape) + multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob) + dots = dots * multiplier + + if self._hard_k > 0: + top_k = np.sort(dots)[..., -self._hard_k] # Get the top-kth weight. + top_k = jax.lax.stop_gradient(top_k) + dots -= top_k[..., np.newaxis] # Subtract (be 0 for lower ones). + dots = np.maximum(dots, 0) + dots_sum = np.sum(dots, axis=-1, keepdims=True) # Re-normalize. + dots /= dots_sum # Re-normalize. + + out_slice = np.matmul(dots, value) + return out_slice + + def forward_and_vjp_slice(query_slice, q_loop_idx, key, value, ct_slice): # pylint: disable=invalid-name + # Capture q_loop_idx to avoid calculated gradients wrt. it. + def forward_slice_with_q_loop_idx(query_slice, key, value): # pylint: disable=invalid-name + return forward_slice(query_slice, q_loop_idx, key, value) + + output_slice, vjpfun = jax.vjp( + forward_slice_with_q_loop_idx, query_slice, key, value) + return output_slice, vjpfun(ct_slice) + + q_loop_idx = np.zeros((), dtype=np.int32) + q_loop_max = query.shape[-2] + q_loop_stride = self._loop_stride + assert q_loop_max % q_loop_stride == 0, ( + 'Stride must evenly divide the number of query elements.') + + out_accum = np.zeros_like(query) + if do_backprop: + query_ct_accum = np.zeros_like(query) + key_ct_accum = np.zeros_like(key) + value_ct_accum = np.zeros_like(value) + init_vals = ( + q_loop_idx, out_accum, + query_ct_accum, key_ct_accum, value_ct_accum) + else: + init_vals = (q_loop_idx, out_accum) + + def cond_fun(vals): # pylint: disable=invalid-name + q_loop_idx = vals[0] + return jax.lax.lt(q_loop_idx, q_loop_max) + + def body_fun(vals): # pylint: disable=invalid-name + """Compute a slice of the attention mechanism.""" + if do_backprop: + (q_loop_idx, out_accum, + query_ct_accum, key_ct_accum, value_ct_accum) = vals + else: + q_loop_idx, out_accum = vals + + query_slice = jax.lax.dynamic_slice_in_dim( + query, q_loop_idx, q_loop_stride, axis=-2) + + if do_backprop: + ct_slice = jax.lax.dynamic_slice_in_dim( + ct, q_loop_idx, q_loop_stride, axis=-2) + out_slice, partial_ct = forward_and_vjp_slice( + query_slice, q_loop_idx, key, value, ct_slice) + query_ct_accum = jax.lax.dynamic_update_slice_in_dim( + query_ct_accum, partial_ct[0], q_loop_idx, axis=-2) + key_ct_accum = key_ct_accum + partial_ct[1] + value_ct_accum = value_ct_accum + partial_ct[2] + else: + out_slice = forward_slice(query_slice, q_loop_idx, key, value) + + out_accum = jax.lax.dynamic_update_slice_in_dim( + out_accum, out_slice, q_loop_idx, axis=-2) + q_loop_idx = q_loop_idx + q_loop_stride + + if do_backprop: + return (q_loop_idx, out_accum, + query_ct_accum, key_ct_accum, value_ct_accum) + else: + return (q_loop_idx, out_accum) + + final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals) + + if not do_backprop: + return final_vals[1], None + else: + return final_vals[1], final_vals[2:] + + +class TimeBinCausalAttention(BaseCausalAttention): + """Causal attention where only nearby chunks of items attend to each other.""" + + def __init__(self, mode, dropout=0.0, bin_length=None, n_bins=None, + share_qk=False): + super(TimeBinCausalAttention, self).__init__() + if (bin_length is None) == (n_bins is None): + raise ValueError('Exactly one of {bin_length, n_bins} must be set.') + self.bin_length = bin_length + self.n_bins = n_bins + self._share_qk = share_qk + if dropout >= 1.0: + raise ValueError('Dropout rates must be lower than 1.') + if mode == 'train': + self.dropout = dropout + else: + self.dropout = 0.0 + self._mode = mode + + def forward_and_backward(self, inputs, ct, **kwargs): + assert backend.get_name() == 'jax', ( + 'JAX backend is required to use forward_and_backward.') + # Simultaneous forward pass and backprop through the attention mechanism. + def _do_forward(x): # pylint: disable=invalid-name + res, _ = self.forward(x, **kwargs) + return res + output, vjpfun = jax.vjp(_do_forward, inputs) + return output, vjpfun(ct)[0] + + def make_unit_length(self, x, epsilon=1e-6): + variance = np.mean(x**2, axis=-1, keepdims=True) + norm_inputs = x / np.sqrt(variance + epsilon) + return norm_inputs + + def _pad_inputs(self, inputs): + seq_len = inputs[0].shape[-2] + n_bins = self.n_bins + bin_length = self.bin_length + if n_bins is None: + n_bins = int(math.ceil(seq_len / bin_length)) + else: + bin_length = int(math.ceil(seq_len / n_bins)) + pad_len = n_bins * bin_length - seq_len + + def pad_input(x): + pad_widths = [(0, 0)] * len(x.shape) + pad_widths[-2] = (0, pad_len) # Padding on axis=-2 + return np.pad(x, pad_widths, mode='constant', + constant_values=x.dtype.type(0)) + + padded_inputs = tuple(map(pad_input, inputs)) + return (padded_inputs, seq_len, n_bins) + + def forward(self, inputs, params=(), state=(), rng=None, **kwargs): + del params, kwargs + if self._mode in ('train', 'eval'): + output = self._forward_train_eval(inputs, rng) + return (output, state) + else: + assert self._mode == 'predict' + return self._forward_predict(inputs, state, rng) + + def _forward_train_eval(self, inputs, rng): + (inputs, original_len, n_bins) = self._pad_inputs(inputs) + q, k, v = inputs + seqlen = q.shape[-2] + # q/k/v are n_batch*n_heads, seqlen, d_head + # Time indices for causal masking. + t = jax.lax.tie_in(q, np.arange(seqlen)) + + # Split off a "bin" axis for chunks of consecutive items. + bq_t = np.reshape(t, (n_bins, -1)) + bq = np.reshape(q, (q.shape[0], n_bins, -1, q.shape[-1])) + if self._share_qk: + bk = self.make_unit_length(bq) + else: + bk = np.reshape(k, (k.shape[0], n_bins, -1, k.shape[-1])) + bv = np.reshape(v, (v.shape[0], n_bins, -1, v.shape[-1])) + + # Allow each chunk to attend within itself, and also one chunk back. + def look_one_back(x): + # Output: pairs [ bin_i bin_{i-1} ] concatenated on the time axis. + if len(x.shape) == 2: + x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0) + return np.concatenate([x, x_extra], axis=1) + else: + assert len(x.shape) == 4 + x_extra = np.concatenate([x[:, -1:, :, :], x[:, :-1, :, :]], axis=1) + return np.concatenate([x, x_extra], axis=2) + + bkv_t = look_one_back(bq_t) + bk = look_one_back(bk) + bv = look_one_back(bv) + + # Dot-product attention. + dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1]) + + # Causal masking based on the time indices. + mask = jax.lax.convert_element_type( + jax.lax.lt(bq_t[None, :, :, None], bkv_t[None, :, None, :]), + np.float32) + dots = dots - 1e9 * mask + + # Mask out attention to self except when no other targets are available. + if self._share_qk: + self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3)) + self_mask = jax.lax.tie_in(dots, self_mask) + dots = dots - 1e5 * self_mask + + if self.dropout > 0.0: + # Dropout is broadcast across the batch+head dimension + dropout_shape = (1, dots.shape[-3], dots.shape[-2], dots.shape[-1]) + keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout) + keep = backend.random.bernoulli(rng, keep_prob, dropout_shape) + multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob) + dots = dots * multiplier + + # Softmax. + dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True)) + bo = np.matmul(dots, bv) + + output = np.reshape(bo, (bo.shape[0], -1, bo.shape[-1])) + assert output.shape == v.shape + return output[..., :original_len, :] + + def _forward_predict(self, inputs, state, rng): + state = _fast_inference_update_state(inputs, state) + + (q, _, _) = inputs + (ks, vs, mask, index) = state + output = DotProductAttention( + q, ks, vs, mask, dropout=self.dropout, mode=self._mode, rng=rng + ) + + def roll_state(state): + """Rolls the buffers backward to make space for new data.""" + (ks, vs, mask, index) = state + # Move the second bin into the first one's place in both buffers. + def roll_buffer(buf): + return jax.ops.index_update( + buf, + jax.ops.index[:, :self.bin_length, :], + buf[:, self.bin_length:, :], + ) + (ks, vs) = map(roll_buffer, (ks, vs)) + # Zero out the second bin in the mask. + mask = jax.ops.index_update( + mask, jax.ops.index[:, :, self.bin_length:], 0 + ) + # Update the index to match the rolled buffers. + index -= self.bin_length + return (ks, vs, mask, index) + + # Once we get to the end of the buffer, move the second bin back to make + # space for new data: [ bin_i bin_{i+1} | ] -> [ bin_{i+1} | bin_{i+1} ], + # where | is where index points at in the buffer. + state = jax.lax.cond( + pred=(index == 2 * self.bin_length), + true_operand=state, + true_fun=roll_state, + false_operand=state, + false_fun=(lambda x: x), + ) + return (output, state) + + def new_params_and_state(self, input_shapes, input_dtype, rng): + if self._mode in ('train', 'eval'): + return (), () + + assert self._mode == 'predict' + assert self.bin_length is not None, ( + 'For fast inference, TimeBinCausalAttention must be parameterized by ' + 'bin_length.' + ) + params = () + state = _fast_inference_init_state( + input_shapes, input_dtype, 2 * self.bin_length + ) + return params, state + + +class LSHCausalAttention(BaseCausalAttention): + """Causal attention based on locality-sensitive hashing.""" + + def __init__(self, dropout, mode, n_bins=64, n_hashes=1, n_buckets=64, + one_rng=False, allow_duplicate_attention=False, + attend_across_buckets=False, hard_k=0, factorize_hash=False, + rehash_each_round=True, drop_for_hash_rate=0.0): + del dropout + self._mode = mode + super(LSHCausalAttention, self).__init__() + assert n_buckets >= n_bins, 'This setting is not recommended: too few bins.' + assert rehash_each_round or allow_duplicate_attention, ( + 'The setting {allow_duplicate_attention=False, rehash_each_round=False}' + ' is not implemented.') + self.n_bins = n_bins + self.n_hashes = n_hashes + self.n_buckets = n_buckets + self._drop_for_hash_rate = drop_for_hash_rate + self._one_rng = one_rng + self._factorize_hash = factorize_hash + self._prng = None + if one_rng: + seed = random.randint(0, 2**31 - 1) + self._prng = backend.random.get_prng(seed) + + self._allow_duplicate_attention = allow_duplicate_attention + self._attend_across_buckets = attend_across_buckets + self._hard_k = hard_k + self._rehash_each_round = rehash_each_round + + def forward(self, inputs, params=(), state=(), rng=None, **kwargs): + del params, kwargs + output, _ = self.batch_call_and_or_grad(inputs[0], inputs[2], rng=rng) + return output, state + + def forward_and_backward(self, inputs, ct, rng=None, **kwargs): + del kwargs + output, (qk_ct, v_ct) = self.batch_call_and_or_grad( + inputs[0], inputs[2], ct=ct, rng=rng) + return output, (qk_ct, np.zeros_like(inputs[1]), v_ct) + + def has_backward(self): + return True + + def backward(self, inputs, output, ct, params=(), state=(), rng=None, + **kwargs): + del output, params, state + _, (qk_ct, v_ct) = self.batch_call_and_or_grad( + inputs[0], inputs[2], return_output=False, ct=ct, rng=rng) + inputs_ct = (qk_ct, np.zeros_like(inputs[1]), v_ct) + return inputs_ct, () + + def batch_call_and_or_grad(self, qk, v, ct=None, return_output=True, + rng=None): + assert return_output or ct is not None, 'No work to perform!' + # pylint: disable=protected-access + stash_buckets = (return_output and ct is None + and base.Layer._STASH_IN is not None) + if return_output and ct is not None and base.Layer._STASH_OUT is not None: + buckets = base.Layer._STASH_OUT.pop(self) + else: + buckets = None + # pylint: enable=protected-access + + # The approach here is to perform attention for one batch element and head + # at a time. Note that there is absolutely no interaction across examples or + # heads: this layer has no parameters, and hashing patterns are also + # different across examples/heads. As a result, batching doesn't give any + # performance gains except in the case of accelerator under-utilization. We + # assume that hash-based attention will be applied primarily to long + # sequences, where unbatched attention for a single head has sufficient + # computation to fill up the accelerator. + + batch_loop_idx = np.zeros((), dtype=np.int32) + batch_loop_max = qk.shape[0] + + init_vals = (batch_loop_idx,) + if return_output: + out_accum = np.zeros_like(qk) + init_vals = init_vals + (out_accum,) + if stash_buckets: + buckets_accum = np.zeros( + [qk.shape[0], self.n_hashes * qk.shape[1]], dtype=np.int32) + init_vals = init_vals + (buckets_accum,) + if ct is not None: + qk_ct_accum = np.zeros_like(qk) + v_ct_accum = np.zeros_like(v) + init_vals = init_vals + (qk_ct_accum, v_ct_accum) + + def cond_fun(vals): + batch_loop_idx = vals[0] + return jax.lax.lt(batch_loop_idx, batch_loop_max) + + def body_fun(vals): + """Performs attention for a single batch element and head.""" + batch_loop_idx = vals[0] + if self._prng is None: + hash_rng = jax.random.fold_in(rng, batch_loop_idx) + else: + # TODO(kitaev): Maybe use the same RNG across examples (but not heads)? + hash_rng = jax.random.fold_in(self._prng, batch_loop_idx) + qk_slice = jax.lax.dynamic_index_in_dim( + qk, batch_loop_idx, axis=0, keepdims=False) + v_slice = jax.lax.dynamic_index_in_dim( + v, batch_loop_idx, axis=0, keepdims=False) + + if buckets is None: + buckets_slice = self.hash_vectors(qk_slice, rng=hash_rng) + else: + buckets_slice = jax.lax.dynamic_index_in_dim( + buckets, batch_loop_idx, axis=0, keepdims=False) + + if ct is None: + out_slice = self.single_call( + qk_slice, v_slice, buckets_slice, hash_rng=hash_rng) + else: + def _do_single_call(qk_slice, v_slice): + return self.single_call( + qk_slice, v_slice, buckets_slice, hash_rng=hash_rng) + ct_slice = jax.lax.dynamic_index_in_dim( + ct, batch_loop_idx, axis=0, keepdims=False) + out_slice, vjpfun = jax.vjp(_do_single_call, qk_slice, v_slice) + qk_ct_slice, v_ct_slice = vjpfun(ct_slice) + + new_vals = (batch_loop_idx + 1,) + if return_output: + out_accum = vals[1] + out_accum = jax.lax.dynamic_update_index_in_dim( + out_accum, out_slice, batch_loop_idx, axis=0) + new_vals = new_vals + (out_accum,) + if stash_buckets: + buckets_accum = vals[2] + buckets_accum = jax.lax.dynamic_update_index_in_dim( + buckets_accum, buckets_slice, batch_loop_idx, axis=0) + new_vals = new_vals + (buckets_accum,) + if ct is not None: + qk_ct_accum, v_ct_accum = vals[-2:] + qk_ct_accum = jax.lax.dynamic_update_index_in_dim( + qk_ct_accum, qk_ct_slice, batch_loop_idx, axis=0) + v_ct_accum = jax.lax.dynamic_update_index_in_dim( + v_ct_accum, v_ct_slice, batch_loop_idx, axis=0) + new_vals = new_vals + (qk_ct_accum, v_ct_accum) + + return new_vals + + final_vals = jax.lax.while_loop(cond_fun, body_fun, init_vals) + + if return_output: + out = final_vals[1] + else: + out = None + + if stash_buckets: + base.Layer._STASH_IN[self] = final_vals[2] # pylint: disable=protected-access + + if ct is not None: + input_ct = final_vals[-2:] + else: + input_ct = None + + return out, input_ct + + def make_unit_length(self, x, epsilon=1e-6): + variance = np.mean(x**2, axis=-1, keepdims=True) + norm_inputs = x / np.sqrt(variance + epsilon) + return norm_inputs + + def drop_for_hash(self, x, rng): + rate = self._drop_for_hash_rate + if self._mode == 'train' and rate > 0.0: + keep = backend.random.bernoulli(rng, 1.0 - rate, x.shape) + return np.where(keep, x / (1.0 - rate), np.zeros_like(x)) + return x + + def hash_vectors(self, vecs, rng): + # See https://arxiv.org/pdf/1509.02897.pdf + # We sample a different random rotation for each round of hashing to + # decrease the probability of hash misses. + assert self.n_buckets % 2 == 0 + + # If we factorize the hash, find a factor dividing n_buckets nicely. + rot_size, factor_list = self.n_buckets, [self.n_buckets] + if self._factorize_hash: + # If we are given a list of factors, verify it and use later. + if isinstance(self._factorize_hash, list): + rot_size, product = 0, 1 + factor_list = self._factorize_hash + for factor in factor_list: + assert factor % 2 == 0 + product *= factor + rot_size += factor + assert product == self.n_buckets + else: # Find one factor if just set to True. + # We want to represent self.n_buckets = factor * rest so that + # (1) both factor and rest are even, and (2) factor + rest is minimal. + # To compute this we start from factor = sqrt(n_buckets) and go down + # with it until we find one that satisfies the constraints above. + factor = int(math.sqrt(self.n_buckets)) + while factor > 0 and not ( + self.n_buckets % factor == 0 and + factor % 2 == 0 and + (self.n_buckets // factor) % 2 == 0): + factor -= 1 + if factor > 2: # Factor of 2 does not warrant the effort. + rot_size = factor + (self.n_buckets // factor) + factor_list = [factor, self.n_buckets // factor] + + random_rotations_shape = ( + vecs.shape[-1], + self.n_hashes if self._rehash_each_round else 1, + rot_size // 2) + + rng = jax.lax.tie_in(vecs, rng) + rng, subrng = backend.random.split(rng) + random_rotations = jax.random.normal( + rng, random_rotations_shape).astype('float32') + # TODO(lukaszkaiser): the dropout mask will be used for all rounds of + # hashing, so it's shared between them. Check if that's what we want. + dropped_vecs = self.drop_for_hash(vecs, subrng) + rotated_vecs = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations) + + if self._rehash_each_round: + if self._factorize_hash and len(factor_list) > 1: + # We factorized self.n_buckets as the product of factor_list. + # Get the buckets for them and combine. + buckets, cur_sum, cur_product = None, 0, 1 + for factor in factor_list: + rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)] + cur_sum += factor // 2 + rv = np.concatenate([rv, -rv], axis=-1) + if buckets is None: + buckets = np.argmax(rv, axis=-1) + else: + buckets += cur_product * np.argmax(rv, axis=-1) + cur_product *= factor + else: + rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1) + buckets = np.argmax(rotated_vecs, axis=-1) + # buckets is now (self.n_hashes, seqlen). Next we add offsets so that + # bucket numbers from different hashing rounds don't overlap. + offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes)) + offsets = np.reshape(offsets * self.n_buckets, (-1, 1)) + buckets = np.reshape(buckets + offsets, (-1,)) + else: + assert not self._factorize_hash + rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1) + # In this configuration, we map each item to the top self.n_hashes buckets + rotated_vecs = np.squeeze(rotated_vecs, 0) + bucket_range = jax.lax.tie_in(vecs, np.arange(rotated_vecs.shape[-1])) + bucket_range = np.reshape(bucket_range, (1, -1)) + bucket_range = np.broadcast_to(bucket_range, rotated_vecs.shape) + + _, buckets = jax.lax.sort_key_val( + rotated_vecs, bucket_range, dimension=-1) + buckets = buckets[:, -self.n_hashes:] + buckets = np.reshape(np.moveaxis(buckets, 0, -1), (-1,)) + + return buckets + + def single_call(self, qk, v, buckets, hash_rng=None): + # We use the same vector as both a query and a key. + seqlen = qk.shape[-2] + assert int(buckets.shape[0]) == self.n_hashes * seqlen + + ticker = jax.lax.tie_in(qk, np.arange(self.n_hashes * seqlen)) + buckets_and_t = seqlen * buckets + (ticker % seqlen) + buckets_and_t = jax.lax.stop_gradient(buckets_and_t) + + # Hash-based sort ("s" at the start of variable names means "sorted") + sbuckets_and_t, sticker = jax.lax.sort_key_val( + buckets_and_t, ticker, dimension=-1) + _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1) + sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t) + sticker = jax.lax.stop_gradient(sticker) + undo_sort = jax.lax.stop_gradient(undo_sort) + + st = (sticker % seqlen) + sqk = np.take(qk, st, axis=0) + sv = np.take(v, st, axis=0) + + # Split off a "bin" axis so that attention only occurs within chunks. + bq_t = bkv_t = np.reshape(st, (self.n_hashes * self.n_bins, -1)) + bqk = np.reshape(sqk, (self.n_hashes * self.n_bins, -1, sqk.shape[-1])) + bv = np.reshape(sv, (self.n_hashes * self.n_bins, -1, sv.shape[-1])) + bq_buckets = bkv_buckets = np.reshape( + sbuckets_and_t // seqlen, (self.n_hashes * self.n_bins, -1)) + + # Hashing operates on unit-length vectors. Unnormalized query vectors are + # fine because they effectively provide a learnable temperature for the + # attention softmax, but normalizing keys is needed so that similarity for + # the purposes of attention correctly corresponds to hash locality. + bq = bqk + bk = self.make_unit_length(bqk) + + # Allow each chunk to attend within itself, and also one chunk back. Chunk + # boundaries might occur in the middle of a sequence of items from the + # same bucket, so this increases the chances of attending to relevant items. + # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster. + def look_one_back(x): + if len(x.shape) == 2: + x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0) + else: + x_extra = np.concatenate([x[-1:, :, :], x[:-1, :, :]], axis=0) + return np.concatenate([x, x_extra], axis=1) + + bk = look_one_back(bk) + bv = look_one_back(bv) + bkv_t = look_one_back(bkv_t) + bkv_buckets = look_one_back(bkv_buckets) + + # Dot-product attention. + dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1]) + + # Causal masking + mask = jax.lax.convert_element_type( + jax.lax.lt(bq_t[:, :, None], bkv_t[:, None, :]), + np.float32) + dots = dots - 1e9 * mask + + # Mask out attention to self except when no other targets are available. + self_mask = jax.lax.convert_element_type( + jax.lax.eq(bq_t[:, :, None], bkv_t[:, None, :]), + np.float32) + dots = dots - 1e5 * self_mask + + # Mask out attention to other hash buckets. + if not self._attend_across_buckets: + bucket_mask = jax.lax.convert_element_type( + jax.lax.ne(bq_buckets[:, :, None], bkv_buckets[:, None, :]), + np.float32) + dots = dots - 1e7 * bucket_mask + + # Don't double-count query-key pairs across multiple rounds of hashing. + # There are two possible strategies here. (1) The default is to count how + # many times a query-key pair is repeated, and to lower its log-prob + # correspondingly at each repetition. (2) When hard_k is set, the code + # instead masks all but the first occurence of each query-key pair. + # TODO(kitaev): is one strategy faster or more numerically stable? + if not self._allow_duplicate_attention: + locs1 = undo_sort // bq_t.shape[-1] + locs2 = (locs1 + 1) % (self.n_hashes * self.n_bins) + if not self._attend_across_buckets: + locs1 = buckets * (self.n_hashes * self.n_bins) + locs1 + locs2 = buckets * (self.n_hashes * self.n_bins) + locs2 + locs = np.moveaxis(np.concatenate([ + np.reshape(locs1, (self.n_hashes, seqlen)), + np.reshape(locs2, (self.n_hashes, seqlen)), + ], 0), 0, -1) # produces shape (seqlen, 2 * self.n_hashes) + slocs = np.take(locs, st, axis=0) + b_locs = np.reshape( + slocs, (self.n_hashes * self.n_bins, -1, 2 * self.n_hashes)) + # Queries always use the primary location (based on locs1). + b_locs1 = b_locs[:, :, None, :self.n_hashes] + if self._hard_k > 0: + range_n_hashes = jax.lax.tie_in(b_locs, np.arange(self.n_hashes)) + nouse_locs = (range_n_hashes[:, None] > range_n_hashes[None, :]) + nouse_locs = 2 * nouse_locs - 1 # 1 = use, -1 = don't use + nouse_locs = np.reshape( + np.broadcast_to(nouse_locs[:, None, :], + (self.n_hashes, self.n_bins, self.n_hashes)), + (self.n_hashes * self.n_bins, 1, 1, self.n_hashes)) + b_locs1 = b_locs1 * nouse_locs + bq_locs = np.broadcast_to( + b_locs1, + b_locs.shape[:2] + (2, self.n_hashes)) + bq_locs = np.reshape(bq_locs, b_locs.shape) + bkv_locs = look_one_back(b_locs) + + dup_counts = np.sum( + jax.lax.convert_element_type( + jax.lax.eq(bq_locs[:, :, None, :], bkv_locs[:, None, :, :]), + np.float32), + axis=-1) + assert dup_counts.shape == dots.shape + if self._hard_k > 0: + dots = dots - 1e7 * jax.lax.stop_gradient(dup_counts) + else: + dots = dots - jax.lax.stop_gradient(np.log(dup_counts + 1e-9)) + + # Each query only attends to the top k most relevant keys. + if self._hard_k > 0: + b_top_dots = np.sort(dots)[..., -self._hard_k:] # Get the top k dots. + b_top_dots = jax.lax.stop_gradient(b_top_dots) + s_top_dots = np.reshape(b_top_dots, (-1, self._hard_k)) + top_dots = np.take(s_top_dots, undo_sort, axis=0) + + merged_top_dots = np.moveaxis( + np.reshape(top_dots, (self.n_hashes, seqlen, self._hard_k)), 0, -1) + merged_top_dots = np.reshape(merged_top_dots, (seqlen, -1)) + + dots_thresh = np.sort(merged_top_dots)[:, -self._hard_k] + # It's possible to compute the partition function at this point, but right + # now this codepath isn't set up for backprop, and there might also be + # issues computing it this way if two dot-products are exactly equal. + + sdots_thresh = dots_thresh[st] + bdots_thresh = np.reshape(sdots_thresh, (self.n_hashes * self.n_bins, -1)) + bdots_thresh = jax.lax.stop_gradient(bdots_thresh) + + top_k_mask = jax.lax.convert_element_type( + dots < bdots_thresh[..., None], np.float32) + dots = dots - 1e7 * jax.lax.stop_gradient(top_k_mask) + + # Softmax. + dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True) + dots = np.exp(dots - dots_logsumexp) + + bo = np.matmul(dots, bv) + so = np.reshape(bo, (-1, bo.shape[-1])) + slogits = np.reshape(dots_logsumexp, (-1,)) + + def unsort_for_output_impl(so, slogits): + o = np.take(so, undo_sort, axis=0) + # Sorting is considerably faster than gather, but first we need to get the + # XLA compiler to abandon the idea of fusing this sort with the input sort + # (which introduces a computation cycle and leads to a crash). + # TODO(kitaev): remove "sticker_" variable if XLA is fixed. + sticker_ = sticker + jax.lax.convert_element_type( + slogits[0] > 0, sticker.dtype) + _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1) + return o, logits + + def unsort_for_output_vjp(so, slogits): + """Custom gradient for unsort_for_output.""" + so = jax.lax.stop_gradient(so) + slogits = jax.lax.stop_gradient(slogits) + o, logits = unsort_for_output_impl(so, slogits) + def vjpfun(o_logits_grads): + so_grad = np.take(o_logits_grads[0], sticker, axis=0) + # TODO(kitaev): this exists to match the forward pass, but I'm not sure + # if it's actually required. + buckets_and_t_ = buckets_and_t + jax.lax.convert_element_type( + o_logits_grads[1][0] > 0, buckets_and_t.dtype) + _, slogits_grad = jax.lax.sort_key_val( + buckets_and_t_, o_logits_grads[1], dimension=-1) + return (so_grad, slogits_grad) + return (o, logits), vjpfun + + unsort_for_output = jax.custom_transforms(unsort_for_output_impl) + jax.defvjp_all(unsort_for_output, unsort_for_output_vjp) + o, logits = unsort_for_output_impl(so, slogits) + + if self.n_hashes == 1: + out = o + else: + o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1])) + logits = np.reshape(logits, (self.n_hashes, seqlen, 1)) + probs = np.exp(logits - backend.logsumexp(logits, axis=0, keepdims=True)) + out = np.sum(o * probs, axis=0) + + assert out.shape == v.shape + return out + + +def CausalAttention(d_feature, n_heads=1, + d_attention_key=None, d_attention_value=None, + attention_type=DotProductCausalAttention, + share_qk=False, mode='train'): + """Transformer-style multi-headed causal attention. + + Args: + d_feature: int: dimensionality of feature embedding + n_heads: int: number of attention heads + d_attention_key: int: depth of key vector for each attention head + (default is d_feature // n_heads) + d_attention_value: int: depth of value vector for each attention head + (default is d_feature // n_heads) + attention_type: subclass of BaseCausalAttention: attention class to use + share_qk: bool, whether to share queries and keys + mode: str: 'train' or 'eval' + + Returns: + Multi-headed self-attention result. + """ + if d_attention_key is None: + assert d_feature % n_heads == 0 + d_attention_key = d_feature // n_heads + if d_attention_value is None: + assert d_feature % n_heads == 0 + d_attention_value = d_feature // n_heads + + if share_qk: + pre_attention = [ + cb.Dup(), + cb.Parallel( + ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key), + ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_value), + ), + cb.Dup(), + ] + else: + pre_attention = [ + cb.Dup(), cb.Dup(), + cb.Parallel( + ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key), + ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key), + ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_value), + ), + ] + + return pre_attention + [ + attention_type(mode=mode), + ComputeAttentionOutput(n_heads=n_heads, d_model=d_feature), + ] diff --git a/trax/layers/attention_test.py b/trax/layers/attention_test.py new file mode 100644 index 000000000..f157972c0 --- /dev/null +++ b/trax/layers/attention_test.py @@ -0,0 +1,106 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.layers.attention.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as onp +from tensorflow import test +from trax.layers import attention +from trax.layers import base + + +class AttentionTest(test.TestCase): + + def test_shift_right(self): + # Test shifts right on axis=1 + layer = attention.ShiftRight() + input_np = onp.arange(2*3*3).reshape(2, 3, 3) + output_np = layer(input_np) + self.assertEqual(input_np.shape, output_np.shape) + self.assertAllEqual(onp.array([[[0, 0, 0], + [0, 1, 2], + [3, 4, 5]], + + [[0, 0, 0], + [9, 10, 11], + [12, 13, 14]]]), + output_np) + + def test_shift_right_float(self): + layer = attention.ShiftRight() + input_np = onp.arange(2*3*3).reshape(2, 3, 3).astype(onp.float32) + # Test on a float array. + input_np = input_np.astype(onp.float32) + input_np /= 2.0 + self.assertEqual(input_np.dtype, onp.float32) + + output_np = layer(input_np) + self.assertEqual(input_np.shape, output_np.shape) + self.assertEqual(output_np.dtype, onp.float32) + + self.assertAllEqual(onp.array([[[0., 0., 0.], + [0., 0.5, 1.], + [1.5, 2., 2.5]], + + [[0., 0., 0.], + [4.5, 5., 5.5], + [6., 6.5, 7.]]]), + output_np) + + def test_merged_hashed_causal_attention(self): + qkv_shape = (3, 32, 8) + input_shape = (qkv_shape, qkv_shape, qkv_shape) + layer = attention.MemoryEfficientCausalAttention( + loop_stride=16, dropout=0.1, mode='train') + final_shape = base.check_shape_agreement(layer, input_shape) + self.assertEqual((3, 32, 8), final_shape) + + def test_time_bin_causal_attention_bin_length(self): + qkv_shape = (3, 57, 8) + input_shape = (qkv_shape, qkv_shape, qkv_shape) + layer = attention.TimeBinCausalAttention( + bin_length=16, dropout=0.1, mode='train') + final_shape = base.check_shape_agreement(layer, input_shape) + self.assertEqual((3, 57, 8), final_shape) + + def test_time_bin_causal_attention_n_bins(self): + qkv_shape = (3, 57, 8) + input_shape = (qkv_shape, qkv_shape, qkv_shape) + layer = attention.TimeBinCausalAttention( + n_bins=4, dropout=0.1, mode='train') + final_shape = base.check_shape_agreement(layer, input_shape) + self.assertEqual((3, 57, 8), final_shape) + + def test_time_bin_and_dot_product_causal_attention_are_consistent(self): + dot_product_layer = attention.DotProductCausalAttention( + dropout=0.0, mode='train') + time_bin_layer = attention.TimeBinCausalAttention( + bin_length=4, dropout=0.0, mode='train') + + # Exactly 2 bins. + input_shape = (3, 8, 8) + inputs = [onp.random.uniform(size=input_shape) for _ in range(3)] + + dot_product_output = dot_product_layer(inputs) + time_bin_output = time_bin_layer(inputs) + onp.testing.assert_array_almost_equal(dot_product_output, time_bin_output) + + +if __name__ == '__main__': + test.main() diff --git a/trax/layers/base.py b/trax/layers/base.py new file mode 100644 index 000000000..289bb700e --- /dev/null +++ b/trax/layers/base.py @@ -0,0 +1,664 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base layer class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import inspect +import traceback + +import jax + +import numpy as onp +from trax import backend +from trax.backend import nested_map +from trax.backend import ShapeType + + +class Layer(object): + """Base class for composable layers in a deep learning network. + + Layers are the basic building blocks for deep learning models. A Trax layer + computes a function from zero or more inputs to zero or more outputs, + optionally using trainable parameters (common) and non-parameter state (not + common). Authors of new layer subclasses typically override at most two + methods of the base `Layer` class: + + forward(inputs, params=(), state=(), **kwargs): + Computes this layer's output as part of a forward pass through the model. + + new_params_and_state(self, input_shape, input_dtype, rng): + Returns a (params, state) pair suitable for initializing this layer. + + A small subset of layer types are combinators -- they organize the computation + of their sublayers, e.g., applying their sublayers in series or in parallel. + + All layers have the following properties, with default values implemented + in the base `Layer` class: + + - n_inputs: int (default 1) + - n_outputs: int (default 1) + - params: tuple (default empty -- the layer has no parameters) + - state: tuple (default empty -- the layer has no non-parameter state) + - sublayers: tuple (default empty -- the layer has no sublayers) + + The inputs to a layer are tensors, packaged according to how many there are: + + - n_inputs = 0: an empty tuple () + - n_inputs = 1: one tensor (NOT wrapped in a tuple) + - n_inputs > 1: a tuple of tensors + + (The special treatment of the single-input case is meant to simplify the + work of layer writers; this design choice may be revisited in the future.) + + The outputs from a layer are also tensors, packaged the same as layer inputs: + + - n_outputs = 0: an empty tuple () + - n_outputs = 1: the tensor (NOT wrapped in a tuple) + - n_outputs > 1: a tuple of tensors + + The Trax runtime maintains a data stack with which layer calls are composed. + For more complex data network architectures, possibly involving multiple data + flows, one can view each layer as a function from stack state to stack state, + where the function's inputs are a slice from the stack, and the function's + outputs are spliced back into the stack. + """ + + def __init__(self, n_inputs=1, n_outputs=1): + """Creates a partially initialized, unconnected layer instance. + + Args: + n_inputs: Number of inputs expected by this layer. + n_outputs: Number of outputs promised by this layer. + """ + self._n_inputs = n_inputs + self._n_outputs = n_outputs + self._sublayers = () # Default is no sublayers. + self._params = () # cached parameters + self._state = () + self._caller = _find_frame(inspect.stack()) # for custom error messages + self._init_finished = False + + def __repr__(self): + class_str = self.__class__.__name__ + fields_str = 'in={},out={}'.format(self.n_inputs, self.n_outputs) + objs = self.sublayers + if objs: + objs_str = ', '.join(str(x) for x in objs) + return '{}{{{},sublayers=[{}]}}'.format(class_str, fields_str, objs_str) + else: + return '{}{{{}}}'.format(class_str, fields_str) + + def forward(self, inputs, params=(), state=(), **kwargs): + """Computes this layer's output as part of a forward pass through the model. + + Authors of new Layer subclasses should override this method to define the + forward computation that their layer performs. + + Args: + inputs: Input tensors, matching the number (n_inputs) expected by this + layer. Specifically: + - n_inputs = 0: an empty tuple () + - n_inputs = 1: a tensor (NOT wrapped in a tuple) + - n_inputs > 1: a tuple of tensors, with n_inputs items + params: A tuple of trainable parameters, with one element for this layer + if this layer has no sublayers, or one for each sublayer if this + layer has sublayers. If a layer (or sublayer) has no trainable + parameters, the corresponding params element is an empty tuple. + state: Layer-specific non-parameter state that can update between batches. + **kwargs: Often empty; main current use is to carry a PRNG key for random + number generation, using the keyword 'rng'. + + Returns: + Tensors, matching the number (n_outputs) promised by this layer. + Specifically: + - n_outputs = 0: an empty tuple + - n_outputs = 1: one tensor (NOT wrapped in a tuple) + - n_outputs > 1: a tuple of tensors, with n_outputs items + """ + raise NotImplementedError + + def new_params_and_state(self, input_shape, input_dtype, rng): + """Returns a (params, state) pair suitable for initializing this layer. + + Authors of new Layer subclasses should override this method if their layer + uses trainable parameters or has non-parameter state that gets updated + between batches. The default implementation works for layers that have + no parameters or state. + + Args: + input_shape: A tuple representing a shape (if this layer takes one input) + or a tuple of shapes (if this layer takes more than one input). + For example: (210, 160, 3) or ((210, 160, 3), (105, 80, 3)). + input_dtype: Numpy dtype(s) for each of the inputs. + rng: A PRNG key for random number generation. + """ + del input_shape, input_dtype, rng + return (), () + + @property + def n_inputs(self): + """Returns how many tensors this layer expects as input.""" + return self._n_inputs + + @property + def n_outputs(self): + """Returns how many tensors this layer promises as output.""" + return self._n_outputs + + @property + def sublayers(self): + """Returns a tuple containing this layer's sublayers; may be empty.""" + return self._sublayers + + @property + def params(self): + """Returns a tuple containing this layer's parameters; may be empty.""" + return self._params + + @params.setter + def params(self, params): + self._params = params + + @property + def state(self): + """Returns a tuple containing this layer's state; may be empty.""" + return self._state + + @state.setter + def state(self, state): + self._state = state + + @property + def has_backward(self): + """Returns True if this layer provides its own (custom) backward pass code. + + A layer subclass that provides custom backward pass code (for custom + gradients) must override this method to return True. + """ + return False + + def backward(self, inputs, output, grad, params, state, **kwargs): + """Custom backward pass to propagate gradients in a custom way. + + Args: + inputs: Input tensors; can be a (possibly nested) tuple. + output: The result of running this layer on inputs. + grad: gradient signal (called cotangent in jax) computed based on + subsequent layers. The structure and shape must match output. + params: layer parameters + state: start state. + **kwargs: kwargs for the layer + + Returns: + The custom gradient signal for the input. Note that we need to return + a gradient for each argument of forward, so it will usually be a tuple + of signals: the gradient for inputs and parameters. + """ + raise NotImplementedError + + # End of subclassing interface, all functions below are internal. + + def pseudo_forward(self, pseudo_inputs, params, state): + """Computes shapes and types this layer would produce for the given inputs. + + Args: + pseudo_inputs: A ShapeType instance (input data minus the actual values) + or a tuple of ShapeType instances, following the same conventions as + Layer.forward's input arg. + params: Parameters for this layer. + state: start state. + + Returns: + A tuple of (output, state). + + The output part of the tuple is a ShapeType instance representing the + shape and type of the output (if this layer has one output) or a tuple + of ShapeType instances (if this layer has more than one output). + """ + try: + # Beware: using an actual RNG (as opposed to this ShapeType stub) would + # cause a large number of dropout masks to be computed and permanently + # stored in global memory. + rng = ShapeType(shape=(2,), dtype=onp.uint32) + def call_on_input(x, params, state, rng): + return self.forward(x, params=params, state=state, rng=rng) + params_shapes = nested_map( + params, lambda x: ShapeType(shape=x.shape, dtype=x.dtype)) + s = backend.eval_on_shapes(call_on_input)(pseudo_inputs, + params_shapes, state, rng) + return s + except Exception: + name, trace = self.__class__.__name__, _short_traceback(skip=3) + raise LayerError(name, 'pseudo_forward', self._caller, pseudo_inputs, + None, trace) + + def initialize_once(self, input_shapes, input_dtype, rng): + """Initializes this layer and its sublayers recursively. + + This method is designed to initialize each layer instance once, even if the + same layer instance occurs in multiple places in the network. This enables + weight sharing to be implemented as layer sharing. + + Args: + input_shapes: A tuple representing a shape (if this layer takes one input) + or a tuple of shapes (if this layer takes more than one input). + For example: (210, 160, 3) or ((210, 160, 3), (105, 80, 3)). + input_dtype: Numpy dtype(s) for each of the inputs. + rng: A PRNG key for random number generation. + + Returns: + A (params, state) tuple, in which params contains newly created parameters + on the first call and () on all subsequent calls. + """ + try: + # Initialize params once; store them for use when this layer is called. + # Needs to call new_params_and_state regardless of _init_finished because + # state also needs to be initialized. After jitting, graph pruning should + # be able to remove unnecessary computation. + # TODO(lukaszkaiser): Revisit this decision and see whether layers sharing + # params should also share states. + params, state = self.new_params_and_state(input_shapes, input_dtype, rng) + if not self._init_finished: + self._init_finished = True + self._params = params + self._state = state + else: + params = () + return (params, state) + except Exception: + name, trace = self.__class__.__name__, _short_traceback(skip=3) + raise LayerError(name, 'initialize_once', self._caller, input_shapes, + input_dtype, trace) + + # XXX(kitaev): + _STASH_IN = None + _STASH_OUT = None + + def __call__(self, x, **kwargs): + """Makes Layer instances callable; for use in tests or interactive settings. + + This convenience method helps library users play with, test, or otherwise + probe the behavior of layers outside of a full training environment. It + presents the layer as callable function from inputs to outputs, with the + option of manually specifying parameters and non-parameter state per + individual call. For convenience, parameters and non-parameter state are + cached per layer instance, starting from default values of () and (), and + acquiring non-empty values either by initialization or from values + explicitly provided via the params and state keyword arguments. + + Args: + x: 0 or more input tensors, formatted the same as the inputs to + Layer.forward. + **kwargs: Additional keyword arguments if needed/desired for this layer. + Three possible keyword arguments are especially relevant: + - params=... will override any cached params values + - state=... will override any cached state values + - rng=... will supply a PRNG key for use by the layer + + Returns: + 0 or more output tensors, formatted the same as the outputs from + Layer.forward. + """ + params = kwargs.pop('params', self.params) + state = kwargs.pop('state', self.state) + outputs, _ = self.apply_forward(x, params=params, state=state, **kwargs) + return outputs + + def apply_forward(self, x, params=(), state=(), **kwargs): + """Applies this layer as part of a forward pass; an internal system method. + + This method is reserved for handling plumbing and other internal affairs + as needed by the overall library. Trax library users should use or override + the `forward` method instead. + + Args: + x: See Layer.forward inputs. + params: See Layer.forward. + state: See Layer.forward. + **kwargs: See Layer.forward. + + Returns: + See Layer.forward. + """ + try: + # If params are nothing, we may be reusing this layer. + # Use the cached parameters to calculate the value. + # Note: to make sure jit tracers can decide this branch in python we + # use "params is ()" instead of, e.g., "not params" or "params == ()". + if params is (): # pylint: disable=literal-comparison + params = self._params + else: + # In this case, we're called for the first time: cache parameters. + self._params = params + + if not self.has_backward or Layer._STASH_IN is not None: + outputs, s = self.forward(x, params=params, state=state, **kwargs) + else: + outputs, s = self._do_custom_gradients(x, params, state, **kwargs) + self._state = s + return outputs, s + + except Exception: + name, trace = self.__class__.__name__, _short_traceback() + raise LayerError(name, 'apply_forward', self._caller, + shapes(x), None, trace) + + def _do_custom_gradients(self, x, params, state, **kwargs): + """Calls this layer for a forward pass, but with custom gradients.""" + assert backend.get_name() == 'jax', ( + 'Custom gradients are only supported in JAX for now.') + + # TODO(wangpeng): JAX doesn't support custom grads for functions with + # auxiliary output yet (https://github.com/google/jax/issues/844). Will + # remove the constraints on state below when this feature is added to + # JAX. + + assert not jax.tree_util.tree_leaves(state), ( + 'Custom gradients require trivial start state. Got %s' % str(state)) + + def check_end_state(output_state): + output, state = output_state + assert not jax.tree_util.tree_leaves(state), ( + 'Custom gradients require trivial end state. Got %s' % str(state)) + return output + + # See this link for how custom transformations are defined in JAX: + # https://jax.readthedocs.io/en/latest/jax.html#jax.custom_transforms + # Note that we capture the kwargs and don't calculate gradients wrt. them. + @jax.custom_transforms + def _do_forward(y, params): + return check_end_state(self.forward(y, params=params, state=state, + **kwargs)) + + # This is the custom gradient (vector-jacobian product in JAX) function. + # For the exact specification of this custom transformation see this link: + # https://jax.readthedocs.io/en/latest/jax.html#jax.defjvp_all + def do_forward_vjp(y, params): + """Custom gradient (vjp) function.""" + stash = None + if Layer._STASH_IN is None: + Layer._STASH_IN = stash = {} + output = check_end_state(self.forward(y, params=params, state=state, + **kwargs)) + if stash is not None: + Layer._STASH_IN = None + def vjpfun(grad): + assert Layer._STASH_OUT is None + Layer._STASH_OUT = stash + res = self.backward(y, output, grad, params, state, **kwargs) + Layer._STASH_OUT = None + return res + return output, vjpfun + + jax.defvjp_all(_do_forward, do_forward_vjp) + return _do_forward(x, params), state + + +class LayerError(Exception): + """Exception raised in the layer stack. + + Attributes: + message: the message corresponding to this exception. + """ + + def __init__(self, layer_name, function_name, caller, + input_shapes, input_types, traceback_string): + self._layer_name = layer_name + self._function_name = function_name + self._caller = caller # Python inspect object with init caller info. + self._traceback = traceback_string + self._input_shapes = input_shapes + self._input_types = input_types + super(LayerError, self).__init__(self.message) + + @property + def message(self): + """Create error message.""" + prefix = 'Exception passing through layer ' + prefix += '%s (in %s):\n' % (self._layer_name, self._function_name) + short_path = '[...]/' + '/'.join(self._caller.filename.split('/')[-3:]) + caller = ' layer created in file %s, line %d\n' % (short_path, + self._caller.lineno) + shapes_str = ' layer input shapes: %s\n\n' % str(self._input_shapes) + if self._input_types is not None: + types_str = ' layer input types: %s\n' % str(self._input_types) + shapes_str = types_str + shapes_str + return prefix + caller + shapes_str + self._traceback + + +def _apply_to_first_n(f, x, n): + """Helper: apply f to first n elements on the stack x if n > 0.""" + if n < 1: + return f(x) + argument, rest = x[:n], x[n:] + if n == 1: + argument = argument[0] + result = f(argument) + if not rest: + return result + if n == 1: + result = [result] + result = list(result) + list(rest) + if isinstance(x, tuple): + result = tuple(result) + return result + + +def nested_reduce(x, f): + """Fold the function f to the nested structure x (dicts, tuples, lists).""" + if isinstance(x, list): + return f([nested_reduce(y, f) for y in x]) + if isinstance(x, tuple): + return f([nested_reduce(y, f) for y in x]) + return x + + +def shapes(x): + """Get a structure of shapes for a structure of nested arrays.""" + def shape(x): + try: + return tuple([int(i) for i in x.shape]) + except Exception: # pylint: disable=broad-except + return [] + return nested_map(x, shape) + + +def sizes(x): + """Get a structure of sizes for a structure of nested arrays.""" + def size(x): + try: + return x.size + except Exception: # pylint: disable=broad-except + return 0 + return nested_map(x, size) + + +def _find_frame(stack, start=0): + """Find the frame with the caller on the stack.""" + # We want to find the first place where the layer was called + # that is *not* an __init__ function of an inheriting layer. + frame = inspect.getframeinfo(stack[start][0]) + # If we are in an init, move on. + if frame.function == '__init__': + return _find_frame(stack, start + 1) + return frame + + +def _shorten_file_path(line): + """Shorten file path in error lines for more readable tracebacks.""" + start = line.lower().find('file') + if start < 0: + return line + first_quote = line.find('"', start) + if first_quote < 0: + return line + second_quote = line.find('"', first_quote + 1) + if second_quote < 0: + return line + path = line[first_quote + 1:second_quote] + new_path = '/'.join(path.split('/')[-3:]) + return line[:first_quote] + '[...]/' + new_path + line[second_quote + 1:] + + +def _short_traceback(skip=3): + """Cleaned-up form of traceback.""" + counter, res = 0, [] + # Skipping 3 lines by default: the top (useless) and self-call. + lines = traceback.format_exc().splitlines()[skip:] + for l in lines: + res.append(_shorten_file_path(l)) + if counter % 2 == 1: + res.append('') + counter += 1 + # If we see a LayerError, the traceback has already been processed. + if l.startswith('LayerError'): + # Skip 4 back except last as these are internal base-layer calls. + res = res[:-4] + [res[-1]] + res += lines[counter:] + break + return '\n'.join(res) + + +def _validate_forward_input(x, n_inputs): + if n_inputs != 1: + if not isinstance(x, tuple): + raise TypeError( + 'expected input to be a tuple; instead received {}'.format(type(x))) + if len(x) != n_inputs: + raise ValueError( + 'input tuple length ({}) does not equal required number of inputs' + ' ({})'.format(len(x), n_inputs)) + + +def layer(n_inputs=1, n_outputs=1, new_params_and_state_fn=None): + """Returns a decorator that converts a function into a Layer class builder.""" + + def _build_layer_class(raw_fn): + """Returns a Layer class whose callable instances execute the function.""" + + def _init(self, **kwargs): + self._kwargs = kwargs # pylint: disable=protected-access + Layer.__init__(self, n_inputs=n_inputs, n_outputs=n_outputs) + + def _new_params_and_state(self, input_shapes, input_dtype, rng): + if new_params_and_state_fn is None: + return (), () + kwargs = self._kwargs # pylint: disable=protected-access + return new_params_and_state_fn(input_shapes, input_dtype, rng, **kwargs) + + def _is_empty(raw_output): + return raw_output is None or (isinstance(raw_output, (list, tuple)) + and len(raw_output) == 0) # pylint: disable=g-explicit-length-test + + def _forward(self, x, params=(), state=(), **kwargs): + """Uses this layer as part of a forward pass through the model.""" + merged_kwargs = kwargs.copy() + merged_kwargs.update(self._kwargs) # pylint: disable=protected-access + + _validate_forward_input(x, n_inputs) + raw_output = raw_fn(x, params=params, **merged_kwargs) + output = () if _is_empty(raw_output) else raw_output + return (output, state) + + # Set docstrings and create the class. + _forward.__doc__ = raw_fn.__doc__ + _new_params_and_state.__doc__ = new_params_and_state_fn.__doc__ + # Note: None.__doc__ is None + cls = type(raw_fn.__name__, (Layer,), + {'__init__': _init, + 'forward': _forward, + 'new_params_and_state': _new_params_and_state}) + return cls + + return _build_layer_class + + +def _random_values(input_shapes, rng, integer_inputs=False): + """Creates random floats or ints of the given shape. + + Args: + input_shapes: A tuple representing a shape (if the layer takes one input) + or a tuple of shapes (if this layer takes more than one input). + For example: (210, 160, 3) or ((210, 160, 3), (105, 80, 3)). + rng: A random number generator. + integer_inputs: If True, use numpy int32 to produce the random data, else + use float32. + + Returns: + Random values with the shape and type specified. + """ + if isinstance(input_shapes[0], int): + # Non-nested shape, create a random tuple. + if not integer_inputs: + return backend.random.uniform(rng, input_shapes, minval=-1.0, maxval=1.0) + return backend.random.bernoulli(rng, 0.5, input_shapes).astype(onp.int32) + elif isinstance(input_shapes, tuple): # Nested shape: tuple. + return tuple(_random_values(x, rng, integer_inputs) for x in input_shapes) + else: + raise TypeError(type(input_shapes)) + + +def _is_tuple_of_shapes(shape): + # TODO(jonni): Find better way to distinguish a shape from a tuple of shapes. + if not isinstance(shape, tuple): + raise TypeError('shape must be a tuple or tuple of tuples, instead got:' + ' {}'.format(shape)) + return isinstance(shape, tuple) and isinstance(shape[0], tuple) + + +def check_shape_agreement(layer_obj, input_shapes, integer_inputs=False): + """Checks if the layer's call output agrees its pseudo_forward predictions. + + This function helps test layer mechanics and inter-layer connections that + aren't dependent on specific data values. + + Args: + layer_obj: A Layer instance. + input_shapes: A tuple representing a shape (if the layer takes one input) + or a tuple of shapes (if this layer takes more than one input). + For example: (210, 160, 3) or ((210, 160, 3), (105, 80, 3)). + integer_inputs: If True, use numpy int32 as the type for the pseudo-data, + else use float32. + + Returns: + A tuple representing either a single shape (if the layer has one output) or + a tuple of shape tuples (if the layer has more than one output). + """ + rng1, rng2, rng3 = backend.random.split(backend.random.get_prng(0), 3) + input_dtype = onp.int32 if integer_inputs else onp.float32 + if _is_tuple_of_shapes(input_shapes): + pseudo_data = tuple(ShapeType(x, input_dtype) for x in input_shapes) + input_dtype = tuple(input_dtype for _ in input_shapes) + else: + pseudo_data = ShapeType(input_shapes, input_dtype) + params, state = layer_obj.initialize_once(input_shapes, input_dtype, rng1) + pseudo_output, _ = layer_obj.pseudo_forward(pseudo_data, params, state) + if isinstance(pseudo_output, tuple): + output_shape = tuple(x.shape for x in pseudo_output) + else: + output_shape = pseudo_output.shape + + random_input = _random_values(input_shapes, rng2, integer_inputs) + real_output = layer_obj(random_input, params=params, state=state, rng=rng3) + result_shape = shapes(real_output) + + msg = 'output shape %s != real result shape %s' % (output_shape, result_shape) + assert output_shape == result_shape, msg + # TODO(jonni): Remove this assert? It makes test logs harder to read. + return output_shape diff --git a/trax/layers/base_test.py b/trax/layers/base_test.py new file mode 100644 index 000000000..e0734b91e --- /dev/null +++ b/trax/layers/base_test.py @@ -0,0 +1,90 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for base layer.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from trax import backend +from trax.layers import base + + +class BaseLayerTest(absltest.TestCase): + + def test_layer_decorator_and_shape_agreement(self): + @base.layer() + def add_one(x, **unused_kwargs): + return x + 1 + + output_shape = base.check_shape_agreement( + add_one(), (12, 17)) # pylint: disable=no-value-for-parameter + self.assertEqual(output_shape, (12, 17)) + + def test_custom_zero_grad(self): + + class IdWithZeroGrad(base.Layer): + + def forward(self, x, params=(), state=(), **kwargs): + del kwargs + return x, () + + @property + def has_backward(self): + return True + + def backward(self, inputs, output, ct, params, state, **kwargs): + return (backend.numpy.zeros_like(ct), ()) + + layer = IdWithZeroGrad() + rng = backend.random.get_prng(0) + input_shape = (9, 17) + random_input = backend.random.uniform(rng, input_shape, minval=-1.0, + maxval=1.0) + layer.initialize_once(input_shape, random_input.dtype, rng) + f = lambda x: backend.numpy.mean(layer(x)) + grad = backend.grad(f)(random_input) + self.assertEqual(grad.shape, input_shape) # Gradient for each input. + self.assertEqual(sum(sum(grad * grad)), 0.0) # Each one is 0. + + def test_custom_id_grad(self): + + class IdWithIdGrad(base.Layer): + + def forward(self, x, params=(), state=(), **kwargs): + del kwargs + return x, () + + @property + def has_backward(self): + return True + + def backward(self, inputs, output, ct, params, state, **kwargs): + return (inputs, ()) + + layer = IdWithIdGrad() + rng = backend.random.get_prng(0) + input_shape = (9, 17) + random_input = backend.random.uniform(rng, input_shape, minval=-1.0, + maxval=1.0) + layer.initialize_once(input_shape, random_input.dtype, rng) + f = lambda x: backend.numpy.mean(layer(x)) + grad = backend.grad(f)(random_input) + self.assertEqual(grad.shape, input_shape) # Gradient for each input. + self.assertEqual(sum(sum(grad)), sum(sum(random_input))) # Same as input. + +if __name__ == "__main__": + absltest.main() diff --git a/trax/layers/combinators.py b/trax/layers/combinators.py new file mode 100644 index 000000000..bcd34c711 --- /dev/null +++ b/trax/layers/combinators.py @@ -0,0 +1,539 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Combinators for composing layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from trax import backend +from trax.backend import numpy as np +from trax.layers import base + + +def Model(*layers): + """Ensures that a layer or list of layers can be treated as a model. + + Currently, any subclass of base.Layer can be treated as a model. + + Args: + *layers: One or more layer objects. In fuller detail, the list may contain + nested sublists, and the top-level list can also be a tuple. + + Returns: + A single object that treated as a model, e.g., trained or evaluated. + """ + return Serial(*layers) + + +def _deep_flatten(items): # pylint: disable=invalid-name + """Returns a list of objects, flattening sublists/subtuples along the way. + + Example: _deep_flatten([1, (2, 3, (4, 5), [6, 7]), [[[8]]]]) would return + the list [1, 2, 3, 4, 5, 6, 7, 8]. + + Args: + items: An iterable. If elements of this iterable are lists or tuples, they + will be (recursively) flattened until non-list non-tuple objects are + reached. + + Returns: + A list of non-list, non-tuple objects. + """ + def _flat_gen(xs): # pylint: disable=invalid-name + for x in xs: + if isinstance(x, (list, tuple)): + for y in _flat_gen(x): + yield y + else: + yield x + return list(_flat_gen(items)) + + +def _ensure_sublayers(layers): # pylint: disable=invalid-name + """Ensures that elements in a layer list are layers. + + Args: + layers: A tuple or list whose elements can each be a layer, tuple, or list, + and so on recursively. + + Returns: + An analogous collection of layers in which embedded layer lists are + wrapped in Serial layer instances. + """ + if not layers: # None or an empty list can signal a no-op. + return Serial(None) # no-op, but still handles shapes and initialization + elif isinstance(layers, (list, tuple)): + sublayers_not_lists = [] + for layer in layers: + sublayers_not_lists.append( + Serial(layer) if isinstance(layer, (list, tuple)) else layer) + return sublayers_not_lists + else: + raise TypeError(type(layers)) + + +def _pop_rng_and_split(args_dict, n_copies): # pylint: disable=invalid-name + rng = args_dict.pop('rng', None) + if rng is None: + return (None,) * n_copies + return backend.random.split(rng, n_copies) + + +def _count_items(xs): # pylint: disable=invalid-name + return len(xs) if isinstance(xs, (list, tuple)) else 1 + + +class Serial(base.Layer): + """Combinator that applies layers serially (by function composition). + + A Serial combinator uses stack semantics to manage data for its sublayers. + Each sublayer sees only the inputs it needs and returns only the outputs it + has generated. The sublayers interact via the data stack. For instance, a + sublayer k, following sublayer j, gets called with the data stack in the + state left after layer j has applied. The Serial combinator then: + + - takes N_in items off the top of the stack (N_in = k.n_inputs) and calls + layer k, passing those items as arguments; and + + - takes layer k's N_out return values (N_out = k.n_outputs) and pushes + them onto the data stack. + + A Serial instance with no sublayers acts as a special-case (but useful) + 1-input 1-output no-op. + """ + + def __init__(self, *layers): + super(Serial, self).__init__() + + layers = self._ensure_flat(layers) + self._sublayers = layers + self._n_layers = len(layers) + + if layers: + self._n_inputs, self._n_outputs = self._n_inputs_n_outputs(layers) + + def _ensure_flat(self, layers): + """Ensures that layers is a single flat list of Layer instances.""" + del self + if len(layers) == 1 and layers[0] is None: + layers = () + else: + layers = _deep_flatten(layers) + for obj in layers: + if not isinstance(obj, base.Layer): + raise ValueError( + 'Found nonlayer object ({}) in layers: {}.'.format(obj, layers)) + return layers + + def _n_inputs_n_outputs(self, layers): + del self + running_max = 0 + running_total = 0 + for layer in layers: + running_total += layer.n_inputs + running_max = max(running_max, running_total) + running_total -= layer.n_outputs + return running_max, (running_max - running_total) + + def _validate_forward_inputs(self, xs): + if not isinstance(xs, tuple) and self._n_inputs != 1: + raise TypeError( + 'Serial.forward input must be a tuple; instead got {}'.format(xs)) + len_xs = 1 if isinstance(xs, np.ndarray) else len(xs) + if len_xs < self.n_inputs: + raise ValueError( + 'number of inputs ({}) to Serial.forward less than n_inputs' + ' ({})'.format(len(xs), self.n_inputs)) + + @base.Layer.params.setter + def params(self, params): + """Recursively sets params on this layer and all sublayers.""" + self._params = params + assert len(params) == self._n_layers + for layer, sublayer_params in zip(self.sublayers, params): + layer.params = sublayer_params + + @base.Layer.state.setter + def state(self, state): + """Recursively sets non-param state on this layer and all sublayers.""" + self._state = state + assert len(state) == self._n_layers + for layer, sublayer_state in zip(self.sublayers, state): + layer.state = sublayer_state + + def forward(self, xs, params=(), state=(), **kwargs): + self._validate_forward_inputs(xs) + rngs = _pop_rng_and_split(kwargs, self._n_layers) + if not self.sublayers: # No-op: leave args unchanged. + return (xs, state) + + stack = xs + new_state = [] + n_layers = self._n_layers + if n_layers != 1 and len(params) != n_layers: + raise ValueError('number of params ({}) not equal to number of layers ' + '({})'.format(len(params), n_layers)) + if n_layers != 1 and len(state) != n_layers: + raise ValueError('length of state ({}) not equal to number of layers ' + '({})'.format(len(state), n_layers)) + for layer, p, s, rng in zip(self.sublayers, params, state, rngs): + is_stack_just_one_item = (_count_items(stack) == 1) + + # Give layer its args from the stack; treat 1-arg layer specially. + n_in = layer.n_inputs + if n_in == 1 and is_stack_just_one_item: + inputs = stack + elif n_in == 1: + inputs = stack[0] + else: + inputs = stack[:n_in] + outputs, s = layer.apply_forward(inputs, params=p, state=s, rng=rng, + **kwargs) + new_state.append(s) + + # Push outputs onto remaining stack (if any). + if n_in < _count_items(stack): + if layer.n_outputs == 1: + outputs = (outputs,) + stack = outputs + stack[n_in:] + else: + stack = outputs # NOTE: can be single value or tuple. + + return stack, new_state + + def new_params_and_state(self, input_shape, input_dtype, rng): + def MakeShapeType(shape, dtype): + if isinstance(dtype, (list, tuple)): + return tuple(MakeShapeType(s, t) for s, t in zip(shape, dtype)) + return base.ShapeType(shape=shape, dtype=dtype) + + params = [] + states = [] + pseudo_xs = MakeShapeType(input_shape, input_dtype) + for layer in self.sublayers: + rng, layer_rng = backend.random.split(rng) + + # Give layer its args from pseudo_xs; treat 1-arg layer specially. + is_stack_just_one_item = (_count_items(pseudo_xs) == 1) + n_in = layer.n_inputs + if n_in == 1 and is_stack_just_one_item: + inputs = pseudo_xs + elif n_in == 1: + inputs = pseudo_xs[0] + else: + inputs = pseudo_xs[:n_in] + + in_shape = base.nested_map(inputs, lambda x: x.shape) + in_dtype = base.nested_map(inputs, lambda x: x.dtype) + param, state = layer.initialize_once(in_shape, in_dtype, layer_rng) + pparam = layer._params # pylint: disable=protected-access + + outputs, _ = layer.pseudo_forward(inputs, pparam, state) + + # Push outputs onto remaining pseudo_xs (if any). + if n_in < _count_items(pseudo_xs): + if layer.n_outputs == 1: + outputs = (outputs,) + pseudo_xs = outputs + pseudo_xs[n_in:] + else: + pseudo_xs = outputs # NOTE: can be single value or tuple. + + params.append(param) + states.append(state) + return params, states + + +@base.layer(n_outputs=2) +def Dup(x, **unused_kwargs): + """Duplicates (copies) an element.""" + return (x, x) + + +@base.layer(n_inputs=2, n_outputs=2) +def Swap(xs, **unused_kwargs): + """Swaps two elements.""" + return (xs[1], xs[0]) + + +def Dup2(): + """Copy first 2 elements of the stack: (a, b, ...) -> (a, b, a, b, ...).""" + return Serial([ + # Stack is (a, b, ...) + Parallel(Dup(), Dup()), # pylint: disable=no-value-for-parameter + # Stack is (a, a, b, b, ...) + Parallel([], Swap()), # pylint: disable=no-value-for-parameter + # Stack is (a, b, a, b, ...) + ]) + + +def Dup3(): + """Copy 3 elements of the stack: (a, b, c, ...) -> (a, b, c, a, b, c, ...).""" + return Serial([ + # Stack is (a, b, c, ...) + Parallel(Dup(), Dup(), Dup()), # pylint: disable=no-value-for-parameter + # Stack is (a, a, b, b, c, c, ...) + Parallel([], Swap(), Swap()), # pylint: disable=no-value-for-parameter + # Stack is (a, b, a, c, b, c, ...) + Parallel([], [], Swap()), # pylint: disable=no-value-for-parameter + # Stack is (a, b, c, a, b, c, ...) + ]) + + +@base.layer(n_outputs=0) +def Drop(x, **unused_kwargs): + """Drops one element.""" + del x # Just for the compiler. + return () + + +@base.layer(n_inputs=0) +def FlattenList(xs, **unused_kwargs): + """Flatten lists.""" + # TODO(jonni): Consider renaming layer to DeepFlatten. + return tuple(_deep_flatten(xs)) + + +def _nested_op(inputs, op): # pylint: disable=invalid-name + """Helper: apply op over a list of arrays or nested arrays.""" + # If input is a dictionary, apply to the values (ignore keys). + if isinstance(inputs, dict): + return _nested_op(list(inputs.values()), op) + # First the simple non-nested case. + if not isinstance(inputs[0], (list, tuple)): + return op(inputs) + # In the nested case, sum on each axis separately. + result_list = [] + for i in range(len(inputs[0])): + result_list.append(_nested_op([x[i] for x in inputs], op=op)) + if isinstance(inputs[0], list): + return result_list + return tuple(result_list) + + +@base.layer(n_inputs=2) +def Add(xs, **unused_kwargs): + """Adds two tensors.""" + return xs[0] + xs[1] + + +@base.layer(n_inputs=2) +def SubtractTop(xs, **unused_kwargs): + """Subtracts the first tensor from the second.""" + return xs[1] - xs[0] + + +@base.layer(n_inputs=2) +def Multiply(xs, **unused_kwargs): + """Multiplies two tensors.""" + return xs[0] * xs[1] + + +@base.layer(n_inputs=3) +def Gate(xs, **unused_kwargs): + """Implements a gating function on a (memory, gate, candidate) tuple. + + Final update is memory * gate + (1-gate) * candidate + + This gating equation may also be referred to as Highway Network. + Highway Networks: https://arxiv.org/abs/1505.00387 + + Args: + xs: A tuple of memory, gate, candidate + + Returns: + The result of applying gating. + """ + state, gate, candidate = xs + return gate * state + (1.0 - gate) * candidate + + +class Concatenate(base.Layer): + """Concatenates n tensors into a single tensor.""" + + def __init__(self, n_items=2, axis=-1): + super(Concatenate, self).__init__(n_inputs=n_items) + self._n_items = n_items + self._axis = axis + + def forward(self, xs, params=(), state=(), **kwargs): + del params, kwargs + return backend.numpy.concatenate(xs, self._axis), state + + +class Split(base.Layer): + """Splits the input into sections along an axis.""" + + def __init__(self, n_sections=2, axis=-1): + super(Split, self).__init__(n_outputs=n_sections) + self._n_sections = n_sections + self._axis = axis + + def forward(self, inputs, params=(), state=(), **kwargs): + del params, kwargs + res = tuple(backend.numpy.split(inputs, self._n_sections, self._axis)) + return res, state + + +class Parallel(base.Layer): + """Combinator that applies a list of layers in parallel to its inputs. + + Layers in the list apply to successive spans of inputs, where the spans are + determined how many inputs each layer takes. The resulting output is the + (flattened) concatenation of the resepective layer outputs. + + For example, suppose one has three layers: + + - F: 1 input, 1 output + - G: 3 inputs, 1 output + - H: 2 inputs, 2 outputs (h1, h2) + + Then Parallel(F, G, H) will take 6 inputs and give 4 outputs: + + - inputs: a, b, c, d, e, f + - outputs: F(a), G(b, c, d), h1, h2 + + As an important special case, a None argument to Parallel acts as if it takes + one argument, which it leaves unchanged. (It acts as a one-arg no-op.) For + example: + + Parallel(None, F) + + creates a layer that passes its first input unchanged and applies F to the + following input(s). + """ + + def __init__(self, *layers): + """The constructor. + + Args: + *layers: A list of layers. + + Returns: + A new layer in which each of the given layers applies to its corresponding + span of elements in the dataflow stack. + """ + super(Parallel, self).__init__() + layers = self._validate(layers) + self._n_layers = len(layers) + self._sublayers = layers + self._n_inputs = sum(x.n_inputs for x in layers) + self._n_outputs = sum(x.n_outputs for x in layers) + + def _validate(self, layers): + if not layers or len(layers) < 2: + raise ValueError( + 'layers ({}) must be a list with at least two elements'.format( + layers)) + layers = list(layers) # Ensure we can modify layers. + for i, obj in enumerate(layers): + if obj is None or obj == []: # pylint: disable=g-explicit-bool-comparison + layers[i] = Serial(None) + elif isinstance(obj, (list, tuple)): + layers[i] = Serial(obj) + else: + if not isinstance(obj, base.Layer): + raise ValueError( + 'Found nonlayer object ({}) in layers list: [{}].'.format( + obj, layers)) + if layers[i].n_inputs == 0: + raise ValueError( + 'Sublayer with n_inputs = 0 not allowed in Parallel:' + ' {}'.format(layers[i])) + return layers + + def _allot_to_sublayers(self, inputs): + """Divides Parallel's inputs for use by the sublayers. + + Args: + inputs: Tuple of elements. + + Returns: + A tuple that partitions this layer's inputs among its sublayers. + Sublayers that take one argument get that argument directly. All other + sublayers get a tuple of items. + """ + start, end = 0, 0 + sub_inputs = [] + for layer in self.sublayers: + n_in = layer.n_inputs + end = start + n_in + if n_in == 1: + sub_inputs.append(inputs[start]) + else: + sub_inputs.append(inputs[start:end]) + start = end + return tuple(sub_inputs) + + @base.Layer.params.setter + def params(self, params): + """Recursively sets params on this layer and all sublayers.""" + self._params = params + assert len(params) == self._n_layers + for layer, sublayer_params in zip(self.sublayers, params): + layer.params = sublayer_params + + @base.Layer.state.setter + def state(self, state): + """Recursively sets non-param state on this layer and all sublayers.""" + self._state = state + assert len(state) == self._n_layers + for layer, sublayer_state in zip(self.sublayers, state): + layer.state = sublayer_state + + def forward(self, inputs, params=(), state=(), **kwargs): + n_layers, layers = self._n_layers, self.sublayers + sublayer_inputs = self._allot_to_sublayers(inputs) + rngs = _pop_rng_and_split(kwargs, n_layers) + assert len(sublayer_inputs) == n_layers + assert len(params) == n_layers + assert len(state) == n_layers + assert len(rngs) == n_layers + outputs = [] + new_state = [] + for layer, x, p, s, r in zip(layers, sublayer_inputs, params, state, rngs): + # Note that zip silently truncates its result if lengths don't match. + sub_outputs, sub_state = layer.apply_forward(x, params=p, state=s, rng=r, + **kwargs) + if layer.n_outputs == 1: + outputs.append(sub_outputs) + else: + outputs.extend(sub_outputs) + new_state.append(sub_state) + output = outputs[0] if self.n_outputs == 1 else tuple(outputs) + return output, new_state + + def new_params_and_state(self, input_shapes, input_dtypes, rng): + sublayer_shapes = self._allot_to_sublayers(input_shapes) + sublayer_dtypes = self._allot_to_sublayers(input_dtypes) + rngs = backend.random.split(rng, self._n_layers) + inits = [layer.initialize_once(shape, dtype, rng) + for layer, shape, dtype, rng + in zip(self.sublayers, sublayer_shapes, sublayer_dtypes, rngs)] + if not inits: + return (), () + else: + return tuple(zip(*inits)) + + +def Residual(*layers, **kwargs): + """Constructs a residual version of layers, summing input to layers output.""" + shortcut = kwargs.get('shortcut') # default None signals no-op + return [ + Dup(), # pylint: disable=no-value-for-parameter + Parallel(shortcut, layers), + Add(), # pylint: disable=no-value-for-parameter + ] diff --git a/trax/layers/combinators_test.py b/trax/layers/combinators_test.py new file mode 100644 index 000000000..67cca8c7c --- /dev/null +++ b/trax/layers/combinators_test.py @@ -0,0 +1,116 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for combinator layers.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from trax.layers import base +from trax.layers import combinators as cb +from trax.layers import core + + +class CombinatorLayerTest(absltest.TestCase): + + def test_drop(self): + layer = cb.Drop() + input_shape = (3, 2) + expected_shape = () + output_shape = base.check_shape_agreement(layer, input_shape) + self.assertEqual(output_shape, expected_shape) + + def test_dup(self): + layer = cb.Dup() + input_shape = (3, 2) + expected_shape = ((3, 2), (3, 2)) + output_shape = base.check_shape_agreement(layer, input_shape) + self.assertEqual(output_shape, expected_shape) + + def test_swap(self): + layer = cb.Swap() + input_shape = ((3, 2), (4, 7)) + expected_shape = ((4, 7), (3, 2)) + output_shape = base.check_shape_agreement(layer, input_shape) + self.assertEqual(output_shape, expected_shape) + + def test_serial_no_op(self): + layer = cb.Serial(None) + input_shape = ((3, 2), (4, 7)) + expected_shape = ((3, 2), (4, 7)) + output_shape = base.check_shape_agreement(layer, input_shape) + self.assertEqual(output_shape, expected_shape) + + def test_serial_no_op_list(self): + layer = cb.Serial([]) + input_shape = ((3, 2), (4, 7)) + expected_shape = ((3, 2), (4, 7)) + output_shape = base.check_shape_agreement(layer, input_shape) + self.assertEqual(output_shape, expected_shape) + + def test_serial_one_in_one_out(self): + layer = cb.Serial(core.Div(divisor=2.0)) + input_shape = (3, 2) + expected_shape = (3, 2) + output_shape = base.check_shape_agreement(layer, input_shape) + self.assertEqual(output_shape, expected_shape) + + def test_serial_div_div(self): + layer = cb.Serial(core.Div(divisor=2.0), core.Div(divisor=5.0)) + input_shape = (3, 2) + expected_shape = (3, 2) + output_shape = base.check_shape_agreement(layer, input_shape) + self.assertEqual(output_shape, expected_shape) + + def test_serial_dup_dup(self): + layer = cb.Serial(cb.Dup(), cb.Dup()) + input_shape = (3, 2) + expected_shape = ((3, 2), (3, 2), (3, 2)) + output_shape = base.check_shape_agreement(layer, input_shape) + self.assertEqual(output_shape, expected_shape) + + def test_parallel_dup_dup(self): + layer = cb.Parallel(cb.Dup(), cb.Dup()) + input_shape = ((3, 2), (4, 7)) + expected_shape = ((3, 2), (3, 2), (4, 7), (4, 7)) + output_shape = base.check_shape_agreement(layer, input_shape) + self.assertEqual(output_shape, expected_shape) + + def test_parallel_div_div(self): + layer = cb.Parallel(core.Div(divisor=0.5), core.Div(divisor=3.0)) + input_shape = ((3, 2), (4, 7)) + expected_shape = ((3, 2), (4, 7)) + output_shape = base.check_shape_agreement(layer, input_shape) + self.assertEqual(output_shape, expected_shape) + + def test_parallel_no_ops(self): + layer = cb.Parallel([], None) + input_shape = ((3, 2), (4, 7)) + expected_shape = ((3, 2), (4, 7)) + output_shape = base.check_shape_agreement(layer, input_shape) + self.assertEqual(output_shape, expected_shape) + + def test_branch_op_not_defined(self): + with self.assertRaises(AttributeError): + cb.Branch([], []) + + def test_select_op_not_defined(self): + input_shape = ((3, 2), (4, 7)) + with self.assertRaises(AttributeError): + cb.Select(1, input_shape) + +if __name__ == '__main__': + absltest.main() diff --git a/trax/layers/convolution.py b/trax/layers/convolution.py new file mode 100644 index 000000000..ba7eec703 --- /dev/null +++ b/trax/layers/convolution.py @@ -0,0 +1,126 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Trax convolution layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +import operator + +import six + +from trax import backend +from trax.backend import numpy as np +from trax.layers import base +from trax.layers import initializers as init + + +class Conv(base.Layer): + """Layer constructor function for a general convolution layer.""" + + def __init__(self, filters, kernel_size, strides=None, padding='VALID', + dimension_numbers=('NHWC', 'HWIO', 'NHWC'), + kernel_initializer=None, + bias_initializer=init.RandomNormalInitializer(1e-6)): + super(Conv, self).__init__() + self._filters = filters + self._kernel_size = kernel_size + self._padding = padding + self._dimension_numbers = dimension_numbers + self._lhs_spec, self._rhs_spec, self._out_spec = dimension_numbers + self._one = (1,) * len(kernel_size) + self._strides = strides or self._one + self._bias_initializer = bias_initializer + rhs_spec = self._rhs_spec + self._kernel_initializer = kernel_initializer + if kernel_initializer is None: + self._kernel_initializer = init.GlorotNormalInitializer( + rhs_spec.index('O'), rhs_spec.index('I')) + + def _check_nhwc(self): + msg = 'Convolutions on more than 4 dimensions only supported in NHWC.' + assert self._lhs_spec == self._out_spec == 'NHWC', msg + + def forward(self, x, params=(), state=(), **kwargs): + del kwargs + w, b = params + x_shape = list(x.shape) + if len(x_shape) > 4: + self._check_nhwc() + new_batch_dim = six.moves.reduce(operator.mul, x_shape[:-3]) + x = np.reshape(x, [new_batch_dim] + x_shape[-3:]) + res = backend.conv( + x, w, self._strides, self._padding, self._dimension_numbers, + self._one) + b + if len(x_shape) > 4: + res = np.reshape(res, x_shape[:-3] + list(res.shape[-3:])) + return res, state + + def _kernel_shape(self, input_shape): + """Helper to calculate the kernel shape.""" + kernel_size_iter = iter(self._kernel_size) + return [self._filters if c == 'O' else + input_shape[self._lhs_spec.index('C')] if c == 'I' else + next(kernel_size_iter) for c in self._rhs_spec] + + def new_params_and_state(self, input_shape, input_dtype, rng): + del input_dtype + if len(input_shape) > 4: + self._check_nhwc() + new_batch_dim = six.moves.reduce(operator.mul, input_shape[:-3]) + input_shape = [new_batch_dim] + list(input_shape[-3:]) + kernel_shape = self._kernel_shape(input_shape) + bias_shape = [self._filters if c == 'C' else 1 for c in self._out_spec] + bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape)) + w = self._kernel_initializer(kernel_shape, rng) + b = self._bias_initializer(bias_shape, rng) + return (w, b), () + + +class CausalConv(Conv): + """Causal (masked) convolution for [batch x time x depth] sequences. + + Maintains causality along time axis. Used in language modeling tasks. + """ + + def __init__(self, + filters, + kernel_width=3, + kernel_initializer=None, + bias_initializer=init.RandomNormalInitializer(1e-6)): + super(CausalConv, self).__init__( + filters=filters, + kernel_size=(kernel_width,), + strides=None, + padding='VALID', + dimension_numbers=('NWC', 'WIO', 'NWC'), + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer) + + def forward(self, x, params=(), state=(), **kwargs): + assert self._padding == 'VALID' + # Left pad with 0s. Applying an unmasked valid convolution on top of this + # yields a causal convolution. + # TODO(ddohan): Support strided and dilated convolutions. + rate = 1 + effective_kernel_size = int((self._kernel_size[0] - 1) * rate + 1) + pad = effective_kernel_size - 1 + x_leftpad = np.pad(x, pad_width=[[0, 0], [pad, 0], [0, 0]], mode='constant') + + res = super(CausalConv, self).forward(x_leftpad, params) + return res diff --git a/trax/layers/convolution_test.py b/trax/layers/convolution_test.py new file mode 100644 index 000000000..87e7ccc58 --- /dev/null +++ b/trax/layers/convolution_test.py @@ -0,0 +1,53 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for convolution layers.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from trax.layers import base +from trax.layers import convolution + + +class ConvolutionLayerTest(absltest.TestCase): + + def test_conv(self): + input_shape = (29, 5, 5, 20) + result_shape = base.check_shape_agreement( + convolution.Conv(30, (3, 3)), input_shape) + self.assertEqual(result_shape, (29, 3, 3, 30)) + + def test_conv_rebatch(self): + input_shape = (3, 29, 5, 5, 20) + result_shape = base.check_shape_agreement( + convolution.Conv(30, (3, 3)), input_shape) + self.assertEqual(result_shape, (3, 29, 3, 3, 30)) + + +class CausalConvolutionTest(absltest.TestCase): + + def test_causal_conv(self): + input_shape = (29, 5, 20) + conv = convolution.CausalConv(filters=30, kernel_width=3) + result_shape = base.check_shape_agreement(conv, input_shape) + self.assertEqual(result_shape, (29, 5, 30)) + + # TODO(ddohan): How to test for causality? Gradient check between positions? + + +if __name__ == "__main__": + absltest.main() diff --git a/trax/layers/core.py b/trax/layers/core.py new file mode 100644 index 000000000..fb0eb67b9 --- /dev/null +++ b/trax/layers/core.py @@ -0,0 +1,269 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Trax layers library.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import jax +import numpy as onp + +from trax import backend +from trax.backend import numpy as np +from trax.layers import base +from trax.layers import initializers as init + + +@base.layer() +def Relu(x, **unused_kwargs): + return np.maximum(x, np.zeros_like(x)) + + +@base.layer() +def ParametricRelu(x, a=1., **unused_kwargs): + return np.maximum(a * x, np.zeros_like(x)) + + +@base.layer() +def LeakyRelu(x, a=0.01, **unused_kwargs): + return np.where(x >= 0, x, a * x) + + +@base.layer() +def Elu(x, a=1., **unused_kwargs): + return np.where(x > 0, x, a * np.expm1(x)) + + +@base.layer() +def Selu(x, + alpha=1.6732632423543772848170429916717, + lmbda=1.0507009873554804934193349852946): + return lmbda * np.where(x > 0, x, alpha * np.expm1(x)) + + +@base.layer() +def Gelu(x, **unused_kwargs): + return x * backend.erf(x) + + +@base.layer() +def Sigmoid(x, **unused_kwargs): + return backend.expit(x) + + +@base.layer() +def Tanh(x, **unused_kwargs): + return np.tanh(x) + + +@base.layer() +def HardSigmoid(x, **unused_kwargs): + """Linear approximation to sigmoid.""" + return np.maximum(0, np.minimum(1, (1 + x))) + + +@base.layer() +def HardTanh(x, **unused_kwargs): + """Linear approximation to tanh.""" + return np.maximum(-1, np.minimum(1, x)) + + +@base.layer() +def Exp(x, **unused_kwargs): + return np.exp(x) + + +@base.layer() +def LogSoftmax(x, axis=-1, **unused_kwargs): + """Apply log softmax to x: log-normalize along the given axis.""" + return x - backend.logsumexp(x, axis, keepdims=True) + + +@base.layer() +def Softmax(x, axis=-1, **unused_kwargs): + """Apply softmax to x: exponentiate and normalize along the given axis.""" + return np.exp(x - backend.logsumexp(x, axis, keepdims=True)) + + +@base.layer() +def Softplus(x, **unused_kwargs): + return np.logaddexp(x, 0.) + + +@base.layer() +def ToFloat(x, **unused_kwargs): + return x.astype(onp.float32) + + +class Dense(base.Layer): + """A dense (a.k.a. fully-connected, affine) layer.""" + + def __init__(self, + n_units, + kernel_initializer=init.GlorotUniformInitializer(), + bias_initializer=init.RandomNormalInitializer(1e-6)): + super(Dense, self).__init__() + self._n_units = n_units + self._kernel_initializer = kernel_initializer + self._bias_initializer = bias_initializer + + def forward(self, x, params=(), state=(), **kwargs): + del kwargs + w, b = params + return np.dot(x, w) + b, state + + def new_params_and_state(self, input_shape, input_dtype, rng): + del input_dtype + rng1, rng2 = backend.random.split(rng, 2) + w = self._kernel_initializer((input_shape[-1], self._n_units), rng1) + b = self._bias_initializer((self._n_units,), rng2) + return (w, b), () + + +class Embedding(base.Layer): + """Layer constructor function for an embedding layer.""" + + def __init__(self, + d_feature, + vocab_size, + kernel_initializer=init.GlorotUniformInitializer()): + super(Embedding, self).__init__() + self._d_feature = d_feature # feature dimensionality + self._vocab_size = vocab_size + self._kernel_initializer = kernel_initializer + + def forward(self, x, params=(), state=(), **kwargs): + del kwargs + return np.take(params, x, axis=0), state + + def new_params_and_state(self, input_shape, input_dtype, rng): + del input_shape, input_dtype + out_dim = (self._vocab_size, self._d_feature) + params = self._kernel_initializer(out_dim, rng) + return params, () + + +# Flatten. +@base.layer() +def Flatten(x, n_axes_to_keep=1, **unused_kwargs): + if n_axes_to_keep >= len(x.shape): + raise ValueError("n_axes_to_keep[%d] should be less than input's rank[%d]" % + (n_axes_to_keep, len(x.shape))) + return np.reshape(x, (x.shape[:n_axes_to_keep] + (-1,))) + + +class Dropout(base.Layer): + """Dropout.""" + + def __init__(self, rate=0.0, name='dropout', mode='train'): + super(Dropout, self).__init__() + self._initial_rate = rate + # TODO(lukaszkaiser): remove the name property by the end of September'19. + # It's only needed for a specific purpose in the short term, will go. + self._name = 'dropout_' + name + self._mode = mode + + def new_params_and_state(self, input_shape, input_dtype, rng): + del input_shape, input_dtype, rng + params = () + state = {self._name: np.array(self._initial_rate)} + return params, state + + def forward(self, x, params=(), state=(), rng=None, **kwargs): + """Execute dropout.""" + del kwargs + rate = self._initial_rate + if isinstance(state, dict) and self._name in state: + rate = state[self._name] + if rng is None: + msg = ('Dropout layer requires apply_fn to be called with a rng keyword ' + 'argument. That is, instead of `Dropout(params, inputs)`, call ' + 'it like `Dropout(params, inputs, rng=key)`.') + raise ValueError(msg) + if self._mode != 'train': + return x, state + keep = backend.random.bernoulli(rng, 1.0 - rate, x.shape) + return np.where(keep, x / (1.0 - rate), np.zeros_like(x)), state + + +@base.layer() +def Div(x, divisor=1.0, **unused_kwargs): + return x / divisor + + +@base.layer() +def AddConstant(x, constant=0.0, **unused_kwargs): + return x + constant + + +@base.layer() +def MulConstant(x, constant=1.0, **unused_kwargs): + return x * constant + + +def one_hot(x, size, dtype=np.float32): # pylint: disable=invalid-name + """Make a n+1 dim one-hot array from n dim int-categorical array.""" + arange_size = np.arange(size) + if backend.get_name() == 'jax': + # Work around a jax broadcasting issue. + arange_size = jax.lax.tie_in(x, arange_size) + return np.array(x[..., np.newaxis] == arange_size, dtype) + + +# Mean. +@base.layer() +def Mean(x, axis=-1, keepdims=False, **unused_kwargs): + return np.mean(x, axis=axis, keepdims=keepdims) + + +def log_gaussian_pdf(x, mu, sigma): # pylint: disable=invalid-name + """Compute log N(x | mu, sigma).""" + a = mu.shape[-1] * np.log(2 * np.pi) + _, b = np.linalg.slogdet(sigma) + y = np.linalg.solve(sigma, x - mu) + y = np.expand_dims(y, axis=-1) + xm = np.expand_dims(x - mu, axis=-2) + c = np.matmul(xm, y) + c = np.squeeze(np.squeeze(c, axis=-1), axis=-1) + return -0.5 * (a + b + c) + + +def log_gaussian_diag_pdf(x, mu, diag_sigma): # pylint: disable=invalid-name + """Compute log N(x | mu, eye(diag_sigma)).""" + a = mu.shape[-1] * np.log(2 * np.pi) + b = np.sum(np.log(diag_sigma), axis=-1) + y = x - mu / diag_sigma + y = np.expand_dims(y, axis=-1) + xm = np.expand_dims(x - mu, axis=-2) + c = np.matmul(xm, y) + c = np.squeeze(np.squeeze(c, axis=-1), axis=-1) + return -0.5 * (a + b + c) + + +def multigaussian_loss(preds, targets, ngauss=1): # pylint: disable=invalid-name + """Compute mixture of gaussians loss.""" + ndims = targets.shape[-1] + logits = preds[:, :ngauss] + mus = preds[:, ngauss:ngauss*(ndims + 1)] + sigmas = preds[:, ngauss(ndims + 1):] + sigmas = sigmas * sigmas + 1e-6 # Make positive. + loglogits = logits - backend.logsumexp(logits, axis=-1, keepdims=True) + mus = np.reshape(mus, [-1, ngauss, ndims]) + sigmas = np.reshape(sigmas, [-1, ngauss, ndims]) + targets = np.reshape(targets, [-1, 1, ndims]) + glogprobs = log_gaussian_diag_pdf(targets, mus, sigmas) + return backend.logsumexp(loglogits + glogprobs, axis=-1) diff --git a/trax/layers/core_test.py b/trax/layers/core_test.py new file mode 100644 index 000000000..134b6d0d3 --- /dev/null +++ b/trax/layers/core_test.py @@ -0,0 +1,125 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for core layers.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +import numpy as onp +from trax import backend +from trax.layers import base +from trax.layers import combinators +from trax.layers import core + + +class CoreLayerTest(absltest.TestCase): + + def test_flatten_n(self): + input_shape = (29, 87, 10, 20, 30) + + layer = core.Flatten() + expected_shape = (29, 87 * 10 * 20 * 30) + actual_shape = base.check_shape_agreement(layer, input_shape) + self.assertEqual(actual_shape, expected_shape) + + layer = core.Flatten(n_axes_to_keep=2) + expected_shape = (29, 87, 10 * 20 * 30) + actual_shape = base.check_shape_agreement(layer, input_shape) + self.assertEqual(actual_shape, expected_shape) + + layer = core.Flatten(n_axes_to_keep=3) + expected_shape = (29, 87, 10, 20 * 30) + actual_shape = base.check_shape_agreement(layer, input_shape) + self.assertEqual(actual_shape, expected_shape) + + layer = core.Flatten(n_axes_to_keep=4) + expected_shape = (29, 87, 10, 20, 30) + actual_shape = base.check_shape_agreement(layer, input_shape) + self.assertEqual(actual_shape, expected_shape) + + # Not enough dimensions. + with self.assertRaises(base.LayerError): + base.check_shape_agreement(core.Flatten(n_axes_to_keep=5), input_shape) + + with self.assertRaises(base.LayerError): + base.check_shape_agreement(core.Flatten(n_axes_to_keep=6), input_shape) + + def test_div(self): + layer = core.Div(divisor=2.0) + input_np = onp.array([[1, 2, 3], [4, 5, 6]], dtype=onp.float32) + output_np = layer(input_np) + # absltest doesn't have ndarray equalities. + expected_output_np = input_np / 2.0 + self.assertAlmostEqual( + 0.0, + onp.sum((output_np - expected_output_np) ** 2), + delta=1e-6) + + def test_div_shapes(self): + layer = core.Div(divisor=2.0) + input_shape = (3, 2) + expected_shape = (3, 2) + output_shape = base.check_shape_agreement(layer, input_shape) + self.assertEqual(output_shape, expected_shape) + + def test_dense_param_sharing(self): + model1 = combinators.Serial(core.Dense(32), core.Dense(32)) + layer = core.Dense(32) + model2 = combinators.Serial(layer, layer) + + rng1, rng2 = backend.random.split(backend.random.get_prng(0), 2) + params1, _ = model1.initialize_once((1, 32), onp.float32, rng1) + params2, _ = model2.initialize_once((1, 32), onp.float32, rng2) + # The first parameters have 2 kernels of size (32, 32). + self.assertEqual((32, 32), params1[0][0].shape) + self.assertEqual((32, 32), params1[1][0].shape) + # The second parameters have 1 kernel of size (32, 32) and an empty dict. + self.assertEqual((32, 32), params2[0][0].shape) + self.assertEqual((), params2[1]) + + def test_dropout(self): + input_shape = (8, 7, 9) + output_shape = (8, 7, 9) + final_shape = base.check_shape_agreement( + core.Dropout(rate=0.1, mode="train"), input_shape) + self.assertEqual(final_shape, output_shape) + final_shape = base.check_shape_agreement( + core.Dropout(rate=0.1, mode="eval"), input_shape) + self.assertEqual(final_shape, output_shape) + + def test_log_gaussian_pdf(self): + x = onp.zeros((2, 5), dtype=onp.float32) + mu = x + dsigma = onp.eye(5)[None, :, :] + sigma = onp.concatenate([dsigma, 2*dsigma], axis=0) + prob = core.log_gaussian_pdf(x, mu, sigma) + self.assertEqual(prob.shape, (2,)) + self.assertEqual(int(prob[0]), -4) + self.assertEqual(int(prob[1]), -6) + + def test_log_gaussian_diag_pdf(self): + x = onp.zeros((2, 5), dtype=onp.float32) + mu = x + sigma = onp.ones((5,))[None, :] + sigma = onp.concatenate([sigma, 2*sigma], axis=0) + prob = core.log_gaussian_diag_pdf(x, mu, sigma) + self.assertEqual(prob.shape, (2,)) + self.assertEqual(int(prob[0]), -4) + self.assertEqual(int(prob[1]), -6) + +if __name__ == "__main__": + absltest.main() diff --git a/trax/layers/initializers.py b/trax/layers/initializers.py new file mode 100644 index 000000000..28b8363fe --- /dev/null +++ b/trax/layers/initializers.py @@ -0,0 +1,173 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Trax initializers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as onp +from trax import backend + + +def _GetFans(shape, out_dim=-1, in_dim=-2): + """Get the fan-in and fan-out sizes for the given shape and dims.""" + # Temporary fix until numpy.delete supports negative indices. + if out_dim < 0: + out_dim += len(shape) + if in_dim < 0: + in_dim += len(shape) + + receptive_field = backend.numpy.prod(onp.delete(shape, [in_dim, out_dim])) + if len(shape) >= 2: + fan_in, fan_out = shape[in_dim], shape[out_dim] + elif len(shape) == 1: + fan_in = shape[0] + fan_out = shape[0] + else: + fan_in = 1. + fan_out = 1. + fan_in *= receptive_field + fan_out *= receptive_field + return fan_in, fan_out + + +def RandomNormalInitializer(stddev=1e-2): + """An initializer function for random normal coefficients.""" + + def Init(shape, rng): + return (stddev * backend.random.normal(rng, shape)).astype('float32') + + return Init + + +def RandomUniformInitializer(lim=1.0): + """An initializer function for random uniform coefficients.""" + + def Init(shape, rng): + return (backend.random.uniform(rng, shape, backend.numpy.float32, -lim, + lim)) + + return Init + + +def VarianceScalingInitializer(out_dim, in_dim, scale, mode, distribution): + """Initializer capable of adapting its scale to the shape of weights.""" + if scale <= 0.: + raise ValueError('scale must be positive float, {} given'.format(scale)) + if mode not in {'fan_in', 'fan_out', 'fan_avg'}: + raise ValueError( + 'Invalid mode argument:, {}, must be either fan_in, fan_out or fan_avg' + .format(mode)) + + def Init(shape, rng): + """The initializer function.""" + fan_in, fan_out = _GetFans(shape, out_dim, in_dim) + gain = scale + if mode == 'fan_in': + gain /= fan_in + elif mode == 'fan_out': + gain /= fan_out + elif mode == 'fan_avg': + gain /= (fan_in + fan_out) / 2 + if distribution == 'truncated_normal': + # constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) + stddev = backend.numpy.sqrt(gain) / .87962566103423978 + return (backend.random.truncated_normal(rng, -2, 2, shape) * + stddev).astype('float32') + elif distribution == 'normal': + return (backend.random.normal(rng, shape) * + backend.numpy.sqrt(gain)).astype('float32') + elif distribution == 'uniform': + lim = backend.numpy.sqrt(3. * gain) + return (backend.random.uniform(rng, shape, backend.numpy.float32, -lim, + lim)) + else: + raise ValueError('invalid distribution for variance scaling Initializer') + + return Init + + +def GlorotNormalInitializer(out_dim=-1, in_dim=-2, scale=1.): + """An initializer function for random Glorot-scaled coefficients.""" + return VarianceScalingInitializer(out_dim, in_dim, scale, 'fan_avg', 'normal') + + +def GlorotUniformInitializer(out_dim=-1, in_dim=-2, scale=1.): + """An initializer function for random uniform Glorot-scaled coefficients.""" + return VarianceScalingInitializer(out_dim, in_dim, scale, 'fan_avg', + 'uniform') + + +def LeCunNormalInitializer(out_dim=-1, in_dim=-2, scale=1.): + """An initializer function for random LeCun-scaled coefficients.""" + return VarianceScalingInitializer(out_dim, in_dim, scale, 'fan_in', 'normal') + + +def LeCunUniformInitializer(out_dim=-1, in_dim=-2, scale=1.): + """An initializer function for random uniform LeCun-scaled coefficients.""" + return VarianceScalingInitializer(out_dim, in_dim, scale, 'fan_in', 'uniform') + + +def KaimingNormalInitializer(out_dim=-1, in_dim=-2, param=0.): + """An initializer function for random Kaiming-scaled coefficients.""" + return VarianceScalingInitializer(out_dim, in_dim, + 2.0 / backend.numpy.sqrt(1 + param**2), + 'fan_in', 'normal') + + +def KaimingUniformInitializer(out_dim=-1, in_dim=-2, param=0.): + """An initializer function for random uniform Kaiming-scaled coefficients.""" + return VarianceScalingInitializer(out_dim, in_dim, + 2.0 / backend.numpy.sqrt(1 + param**2), + 'fan_in', 'uniform') + + +def OrthogonalInitializer(stddev=1.0): + """Orthogonal Initializer.""" + def Init(shape, rng): + """The orthogonal initializer function.""" + # Have at least 2 elements in shape. + cur_shape = list(shape) + while len(cur_shape) < 2: + cur_shape = [1] + cur_shape + + # Flatten the input shape with the last dimension remaining. + n_rows = 1 + for dim in cur_shape[:-1]: + n_rows *= dim + n_cols = cur_shape[-1] + flat_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols) + + # Generate a random matrix + a = backend.random.normal(rng, flat_shape, dtype=backend.numpy.float32) + + # Compute the qr factorization + q, r = backend.numpy.linalg.qr(a) + + # Make Q uniform + d = backend.numpy.diag(r) + q *= backend.numpy.sign(d) + + # Transpose and reshape back q if needed. + if n_rows < n_cols: + q = backend.numpy.transpose(q) + q = backend.numpy.reshape(q, shape) + + # Return scaled as requested. + return stddev * q + + return Init diff --git a/trax/layers/initializers_test.py b/trax/layers/initializers_test.py new file mode 100644 index 000000000..2263fa47b --- /dev/null +++ b/trax/layers/initializers_test.py @@ -0,0 +1,83 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for initializers.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from trax.backend import random +from trax.layers import initializers + + +class InitializersTest(absltest.TestCase): + + def test_random_normal(self): + initializer = initializers.RandomNormalInitializer() + input_shape = (29, 5, 7, 20) + init_value = initializer(input_shape, random.get_prng(0)) + self.assertEqual(tuple(init_value.shape), input_shape) + + def test_lecun_uniform(self): + initializer = initializers.LeCunUniformInitializer() + input_shape = (29, 5, 7, 20) + init_value = initializer(input_shape, random.get_prng(0)) + self.assertEqual(tuple(init_value.shape), input_shape) + + def test_random_uniform(self): + initializer = initializers.RandomUniformInitializer() + input_shape = (29, 5, 7, 20) + init_value = initializer(input_shape, random.get_prng(0)) + self.assertEqual(tuple(init_value.shape), input_shape) + + def test_glorot_normal(self): + initializer = initializers.GlorotNormalInitializer() + input_shape = (29, 5, 7, 20) + init_value = initializer(input_shape, random.get_prng(0)) + self.assertEqual(tuple(init_value.shape), input_shape) + + def test_glorot_uniform(self): + initializer = initializers.GlorotUniformInitializer() + input_shape = (29, 5, 7, 20) + init_value = initializer(input_shape, random.get_prng(0)) + self.assertEqual(tuple(init_value.shape), input_shape) + + def test_lecun_normal(self): + initializer = initializers.LeCunNormalInitializer() + input_shape = (29, 5, 7, 20) + init_value = initializer(input_shape, random.get_prng(0)) + self.assertEqual(tuple(init_value.shape), input_shape) + + def test_kaiming_normal(self): + initializer = initializers.KaimingNormalInitializer() + input_shape = (29, 5, 7, 20) + init_value = initializer(input_shape, random.get_prng(0)) + self.assertEqual(tuple(init_value.shape), input_shape) + + def test_kaiming_uniform(self): + initializer = initializers.KaimingUniformInitializer() + input_shape = (29, 5, 7, 20) + init_value = initializer(input_shape, random.get_prng(0)) + self.assertEqual(tuple(init_value.shape), input_shape) + + def test_orthogonal(self): + initializer = initializers.OrthogonalInitializer() + input_shape = (29, 5, 7, 20) + init_value = initializer(input_shape, random.get_prng(0)) + self.assertEqual(tuple(init_value.shape), input_shape) + +if __name__ == "__main__": + absltest.main() diff --git a/trax/layers/intro.ipynb b/trax/layers/intro.ipynb new file mode 100644 index 000000000..7287410f4 --- /dev/null +++ b/trax/layers/intro.ipynb @@ -0,0 +1,834 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "7yuytuIllsv1" + }, + "source": [ + "# A Conceptual, Practical Introduction to Trax Layers\n", + "\n", + "This notebook introduces the core concepts and programming components of the Trax library through a series of code samples and explanations. The topics covered in following sections are:\n", + " - **layers**: the basic building blocks and how to combine them into networks\n", + " - **data flows, data stack**: how the Trax runtime moves data through the layers\n", + " - **models**: how to train, evaluate, and run predictions with Trax models\n", + " - **new layer classes**: how to define and test your own Layer classes\n", + "\n", + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "BIl27504La0G" + }, + "source": [ + "## General Setup\n", + "Execute the following few cells (once) before running any of the code samples in this notebook." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "oILRLCWN_16u" + }, + "outputs": [], + "source": [ + "#@title\n", + "# Copyright 2018 Google LLC.\n", + "\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License.\n", + "\n", + "import numpy as onp\n", + "\n", + "\n", + "\n", + "# Import Trax\n", + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "cellView": "both", + "colab": { + "height": 51 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 340, + "status": "ok", + "timestamp": 1570493480872, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "vlGjGoGMTt-D", + "outputId": "26e89335-17c2-4eed-90a2-a372fbfe1d12" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/bin/sh: pip: command not found\n", + "/bin/sh: pip: command not found\n" + ] + } + ], + "source": [ + "#@title Run for installation.\n", + "\n", + "! pip install -q -U trax\n", + "! pip install -q tensorflow\n", + "\n", + "from trax import backend\n", + "from trax import layers as tl" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "bYWNWL9MJHv9" + }, + "outputs": [], + "source": [ + "onp.set_printoptions(precision=3) # Less visual noise in the numerical outputs.\n", + "\n", + "def show_layer_properties(layer_obj, layer_name):\n", + " template = ('{}.n_inputs: {}\\n'\n", + " '{}.n_outputs: {}\\n'\n", + " '{}.sublayers: {}\\n'\n", + " '{}.params: {}\\n')\n", + " print(template.format(layer_name, layer_obj.n_inputs,\n", + " layer_name, layer_obj.n_outputs,\n", + " layer_name, layer_obj.sublayers,\n", + " layer_name, layer_obj.params)) " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-LQ89rFFsEdk" + }, + "source": [ + "# Layers\n", + "\n", + "The Layer class represents Trax's concept of a layer, as summarized in the start of the class's docstring:\n", + "```\n", + "class Layer(object):\n", + " \"\"\"Base class for composable layers in a deep learning network.\n", + "\n", + " Layers are the basic building blocks for deep learning models. A Trax layer\n", + " computes a function from zero or more inputs to zero or more outputs,\n", + " optionally using trainable parameters (common) and non-parameter state (not\n", + " common). Authors of new layer subclasses typically override at most two\n", + " methods of the base `Layer` class:\n", + "\n", + " forward(inputs, params=(), state=(), **kwargs):\n", + " Computes this layer's output as part of a forward pass through the model.\n", + "\n", + " new_params_and_state(self, input_shape, input_dtype, rng):\n", + " Returns a (params, state) pair suitable for initializing this layer.\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "LyLVtdxorDPO" + }, + "source": [ + "## A layer computes a function." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ntZ4_eNQldzL" + }, + "source": [ + "A layer computes a function from zero or more inputs to zero or more outputs. The inputs and outputs are NumPy arrays or JAX objects wrapping NumPy arrays.\n", + "\n", + "The simplest layers, those with no parameters, state or sublayers, can be used without initialization. You can think of them (and test them) like simple mathematical functions. For ease of testing and interactive exploration, layer\n", + "objects implement the `__call__ ` method, so you can call them directly on input data:\n", + "```\n", + "y = layer(x)\n", + "```\n", + "\n", + "Layers are also objects, so you can inspect their properties. For example:\n", + "```\n", + "print('Number of inputs required by this layer: {}'.format(layer.n_inputs))\n", + "```\n", + "\n", + "### Example 1. tl.Relu [1 input, 1 output]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "height": 221 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 1671, + "status": "ok", + "timestamp": 1570493482869, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "V09viOSEQvQe", + "outputId": "6821a3fd-850c-4d10-c233-0617385262e1" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x:\n", + "[[-7. -6. -5. -4. -3.]\n", + " [-2. -1. 0. 1. 2.]\n", + " [ 3. 4. 5. 6. 7.]]\n", + "\n", + "relu(x):\n", + "[[0. 0. 0. 0. 0.]\n", + " [0. 0. 0. 1. 2.]\n", + " [3. 4. 5. 6. 7.]]\n", + "\n", + "number of inputs expected by this layer: 1\n", + "number of outputs promised by this layer: 1\n" + ] + } + ], + "source": [ + "x = onp.arange(-7, 8).reshape(3, -1).astype(onp.float32)\n", + "\n", + "# Create a layer object (a Relu instance) and apply the layer to data x.\n", + "relu = tl.Relu()\n", + "y = relu(x)\n", + "\n", + "# Show input, output, and two layer properties.\n", + "template = ('x:\\n{}\\n\\n'\n", + " 'relu(x):\\n{}\\n\\n'\n", + " 'number of inputs expected by this layer: {}\\n'\n", + " 'number of outputs promised by this layer: {}')\n", + "print(template.format(x, y, relu.n_inputs, relu.n_outputs))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "7sYxIT8crFVE" + }, + "source": [ + "### Example 2. tl.Concatenate [2 inputs, 1 output]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "height": 442 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 1056, + "status": "ok", + "timestamp": 1570493483938, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "LMPPNWXLoOZI", + "outputId": "bbdafa45-49e9-467a-8970-a5d9da406045" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x1:\n", + "[[-7. -6. -5. -4. -3.]\n", + " [-2. -1. 0. 1. 2.]\n", + " [ 3. 4. 5. 6. 7.]]\n", + "\n", + "x2:\n", + "[[-70. -60. -50. -40. -30.]\n", + " [-20. -10. 0. 10. 20.]\n", + " [ 30. 40. 50. 60. 70.]]\n", + "\n", + "concat0([x1, x2]):\n", + "[[ -7. -6. -5. -4. -3.]\n", + " [ -2. -1. 0. 1. 2.]\n", + " [ 3. 4. 5. 6. 7.]\n", + " [-70. -60. -50. -40. -30.]\n", + " [-20. -10. 0. 10. 20.]\n", + " [ 30. 40. 50. 60. 70.]]\n", + "\n", + "concat1([x1, x2]):\n", + "[[ -7. -6. -5. -4. -3. -70. -60. -50. -40. -30.]\n", + " [ -2. -1. 0. 1. 2. -20. -10. 0. 10. 20.]\n", + " [ 3. 4. 5. 6. 7. 30. 40. 50. 60. 70.]]\n", + "\n", + "concat0: Concatenate{in=2,out=1}\n", + "concat1: Concatenate{in=2,out=1}\n" + ] + } + ], + "source": [ + "x1 = onp.arange(-7, 8).reshape(3, -1).astype(onp.float32)\n", + "x2 = 10 * x1\n", + "\n", + "concat0 = tl.Concatenate(axis=0)\n", + "concat1 = tl.Concatenate(axis=1)\n", + "\n", + "y0 = concat0([x1, x2])\n", + "y1 = concat1([x1, x2])\n", + "\n", + "template = ('x1:\\n{}\\n\\n'\n", + " 'x2:\\n{}\\n\\n'\n", + " 'concat0([x1, x2]):\\n{}\\n\\n'\n", + " 'concat1([x1, x2]):\\n{}\\n')\n", + "print(template.format(x1, x2, y0, y1))\n", + "\n", + "# Print abbreviated object representations (useful for debugging).\n", + "print('concat0: {}'.format(concat0))\n", + "print('concat1: {}'.format(concat1))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "1oZv3R8bRMvF" + }, + "source": [ + "## Layers are trainable." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "3d64M7wLryji" + }, + "source": [ + "Most layer types are trainable: they include parameters that modify the computation of outputs from inputs, and they use back-progagated gradients to update those parameters.\n", + "\n", + "Before use, trainable layers must have their parameters initialized, typically using a PRNG (pseudo-random number generator) key for random number generation. Trax's model trainers take care of this behind the scenes, but if you are using a layer in insolation, you have to do the initialization yourself. For this, use the `initialize_once` method:\n", + "\n", + "```\n", + " def initialize_once(self, input_shapes, input_dtype, rng):\n", + " \"\"\"Initializes this layer and its sublayers recursively.\n", + "\n", + " This method is designed to initialize each layer instance once, even if the\n", + " same layer instance occurs in multiple places in the network. This enables\n", + " weight sharing to be implemented as layer sharing.\n", + "\n", + " ...\n", + "```\n", + "\n", + "### Example 3. tl.LayerNorm [1 input, 1 output, has parameters]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "height": 221 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 1829, + "status": "ok", + "timestamp": 1570493485782, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "Ie7iyX91qAx2", + "outputId": "38353c1f-64b4-4d44-9068-02c6042f0c85" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x:\n", + "[[-7. -6. -5. -4. -3.]\n", + " [-2. -1. 0. 1. 2.]\n", + " [ 3. 4. 5. 6. 7.]]\n", + "\n", + "layer_norm(x):\n", + "[[-1.414 -0.707 0. 0.707 1.414]\n", + " [-1.414 -0.707 0. 0.707 1.414]\n", + " [-1.414 -0.707 0. 0.707 1.414]]\n", + "\n", + "layer_norm.params:\n", + "(_FilledConstant([1., 1., 1., 1., 1.], dtype=float32), _FilledConstant([0., 0., 0., 0., 0.], dtype=float32))\n" + ] + } + ], + "source": [ + "prng_key = backend.random.get_prng(0) # Used below for layer initialization.\n", + "x = onp.arange(-7, 8).reshape(3, -1).astype(onp.float32)\n", + "\n", + "layer_norm = tl.LayerNorm()\n", + "layer_norm.initialize_once(x.shape, x.dtype, prng_key)\n", + "y = layer_norm(x)\n", + "\n", + "template = ('x:\\n{}\\n\\n'\n", + " 'layer_norm(x):\\n{}\\n')\n", + "print(template.format(x, y))\n", + "print('layer_norm.params:\\n{}'.format(layer_norm.params))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ZWZUXEJAofH-" + }, + "source": [ + "## Layers combine into layers." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "d47gVdGV1vWw" + }, + "source": [ + "The Trax library authors encourage users, where possible, to build new layers as combinations of existing layers. The library provides a small set of _combinator_ layers for this: layer objects that make a list of layers behave as a single layer (a unit able to compute outputs from inputs, update parameters from gradients, and combine with yet more layers).\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "vC1ymG2j0iyp" + }, + "source": [ + "## Combine with Serial(...)\n", + "\n", + "The most common way to combine layers is serially, using the `Serial` class:\n", + "```\n", + "class Serial(base.Layer):\n", + " \"\"\"Combinator that applies layers serially (by function composition).\n", + "\n", + " A Serial combinator uses stack semantics to manage data for its sublayers.\n", + " Each sublayer sees only the inputs it needs and returns only the outputs it\n", + " has generated. The sublayers interact via the data stack. For instance, a\n", + " sublayer k, following sublayer j, gets called with the data stack in the\n", + " state left after layer j has applied. The Serial combinator then:\n", + "\n", + " - takes N_in items off the top of the stack (N_in = k.n_inputs) and calls\n", + " layer k, passing those items as arguments; and\n", + "\n", + " - takes layer k's N_out return values (N_out = k.n_outputs) and pushes\n", + " them onto the data stack.\n", + "\n", + " ...\n", + "```\n", + "As described above, the output of one layer is the input of the next, which amounts to function composition:\n", + "\n", + "```\n", + "# h(.) = g(f(.))\n", + "layer_h = Serial(\n", + " layer_f,\n", + " layer_g,\n", + ")\n", + "```\n", + "\n", + "### Example 4. y = layer_norm(relu(x)) [1 input, 1 output, has parameters]" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "height": 170 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 1190, + "status": "ok", + "timestamp": 1570493486986, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "dW5fpusjvjmh", + "outputId": "9facd363-1009-4459-d9a6-61aebbcbe0f1" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x:\n", + "[[-7. -6. -5. -4. -3.]\n", + " [-2. -1. 0. 1. 2.]\n", + " [ 3. 4. 5. 6. 7.]]\n", + "\n", + "layer_block(x):\n", + "[[ 0. 0. 0. 0. 0. ]\n", + " [-0.75 -0.75 -0.75 0.5 1.75 ]\n", + " [-1.414 -0.707 0. 0.707 1.414]]\n" + ] + } + ], + "source": [ + "prng_key = backend.random.get_prng(0)\n", + "x = onp.arange(-7, 8).reshape(3, -1).astype(onp.float32)\n", + "\n", + "layer_block = tl.Serial(\n", + " tl.Relu(),\n", + " tl.LayerNorm(),\n", + ")\n", + "layer_block.initialize_once(x.shape, x.dtype, prng_key)\n", + "y = layer_block(x)\n", + "\n", + "template = ('x:\\n{}\\n\\n'\n", + " 'layer_block(x):\\n{}')\n", + "print(template.format(x, y,))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "bRtmN6ckQO1q" + }, + "source": [ + "And we can inspect the block as a whole, as if it were just another layer:\n", + "\n", + "### Example 5. Inspecting a Serial layer." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "height": 102 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 299, + "status": "ok", + "timestamp": 1570493487299, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "D6BpYddZQ1eu", + "outputId": "2dc36622-b5de-4eae-8a82-49b315349c05" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "layer_block:\n", + "Serial{in=1,out=1,sublayers=[Relu{in=1,out=1}, LayerNorm{in=1,out=1}]}\n", + "\n", + "layer_block.params:\n", + "[(), (_FilledConstant([1., 1., 1., 1., 1.], dtype=float32), _FilledConstant([0., 0., 0., 0., 0.], dtype=float32))]\n" + ] + } + ], + "source": [ + "print('layer_block:\\n{}\\n'.format(layer_block))\n", + "\n", + "print('layer_block.params:\\n{}'.format(layer_block.params))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PqVNdoONcTp0" + }, + "source": [ + "## Combine with Parallel(...)\n", + "\n", + "The `Parallel` combinator arranges layers into separate computational channels, each with its own inputs/outputs and gradient flows:\n", + "```\n", + "class Parallel(base.Layer):\n", + " \"\"\"Combinator that applies a list of layers in parallel to its inputs.\n", + "\n", + " Layers in the list apply to successive spans of inputs, where the spans are\n", + " determined how many inputs each layer takes. The resulting output is the\n", + " (flattened) concatenation of the resepective layer outputs.\n", + "\n", + " For example, suppose one has three layers:\n", + "\n", + " - F: 1 input, 1 output\n", + " - G: 3 inputs, 1 output\n", + " - H: 2 inputs, 2 outputs (h1, h2)\n", + "\n", + " Then Parallel(F, G, H) will take 6 inputs and give 4 outputs:\n", + "\n", + " - inputs: a, b, c, d, e, f\n", + " - outputs: F(a), G(b, c, d), h1, h2\n", + "```\n", + "\n", + "Separate (parallel) computation channels make sense when each channel can do its work (computing outputs from inputs) independent of the inputs and outputs of the others.\n", + "\n", + "As a simplistic example, consider writing a converter from three-digit octal (base 8) numerals to their corresponding values. For instance, to do conversions such as\n", + "```\n", + "123 (octal) = 1 * 8^2 + 2 * 8^1 + 3 * 8^0 = 83 (decimal)\n", + "345 (octal) = 3 * 8^2 + 4 * 8^1 + 6 * 8^0 = 229 (decimal)\n", + "567 (octal) = 5 * 8^2 + 6 * 8^1 + 7 * 8^0 = 375 (decimal)\n", + "```\n", + "the digits can first be converted independently, according to their place value (multiply by 64, multiply by 8, or multiply by 1). The following code runs the 64's-place digits ([1, 3, 5]) through one layer, the 8's-place digits ([2, 4, 6]) through a different layer, and the 1's-place digits ([3, 5, 7]) through yet a different layer. These three layers are combined in a Parallel layer:\n", + "\n", + "### Example 6. Processing octal digits in parallel." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "height": 204 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 1749, + "status": "ok", + "timestamp": 1570493489061, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "uQMqq3h_b2jQ", + "outputId": "842ca16a-8a64-4b1c-b997-aa09ec767f14" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "inputs:\n", + "(array([1, 3, 5]), array([2, 4, 6]), array([3, 5, 7]))\n", + "\n", + "octal_place_values(inputs):\n", + "(array([ 64., 192., 320.]), array([16., 32., 48.]), array([3., 5., 7.]))\n", + "\n", + "octal_place_values.n_inputs: 3\n", + "octal_place_values.n_outputs: 3\n", + "octal_place_values.sublayers: [MulConstant{in=1,out=1}, MulConstant{in=1,out=1}, MulConstant{in=1,out=1}]\n", + "octal_place_values.params: ((), (), ())\n", + "\n" + ] + } + ], + "source": [ + "prng_key = backend.random.get_prng(0)\n", + "place_64_digits = onp.array([1, 3, 5])\n", + "place_8_digits = onp.array([2, 4, 6])\n", + "place_1_digits = onp.array([3, 5, 7])\n", + "inputs = (place_64_digits, place_8_digits, place_1_digits)\n", + "input_shapes = [[3]] * 3\n", + "input_dtypes = [onp.int32] * 3\n", + "\n", + "# Create three simple layers, each for converting a different digit in base 8.\n", + "sixty_fours = tl.MulConstant(constant=64.0) # 8^2: 100 in base 8\n", + "eights = tl.MulConstant(constant=8.0) # 8^1: 10 in base 8\n", + "ones = tl.MulConstant(constant=1.0) # 8^0: 1 in base 8\n", + "\n", + "# Create a combined layer to convert digits to values (using big-endian base 8),\n", + "# initialize it, and apply it.\n", + "octal_place_values = tl.Parallel(sixty_fours, eights, ones)\n", + "octal_place_values.initialize_once(input_shapes, input_dtypes, prng_key)\n", + "outputs = octal_place_values(inputs)\n", + "\n", + "# Show inputs, outputs, and properties.\n", + "template = ('inputs:\\n{}\\n\\n'\n", + " 'octal_place_values(inputs):\\n{}\\n')\n", + "print(template.format(inputs, outputs))\n", + "show_layer_properties(octal_place_values, 'octal_place_values')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "q_xcWide3e5f" + }, + "source": [ + "To complete the example, the three outputs (values for the different digits) are combined by successive pairwise additions:\n", + "\n", + "### Example 6'. Combining outputs from upstream parallel layers." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "colab": { + "height": 275 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 968, + "status": "ok", + "timestamp": 1570493490044, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "ZDCkrvUp3u0-", + "outputId": "20823e0d-bb07-4503-82f0-d42eaffcc76e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "inputs:\n", + "(array([1, 3, 5]), array([2, 4, 6]), array([3, 5, 7]))\n", + "\n", + "octal_place_values(inputs):\n", + "(array([ 64., 192., 320.]), array([16., 32., 48.]), array([3., 5., 7.]))\n", + "\n", + "evaluate_octal(inputs):\n", + "[ 83. 229. 375.]\n", + "\n", + "evaluate_octal.n_inputs: 3\n", + "evaluate_octal.n_outputs: 1\n", + "evaluate_octal.sublayers: [Parallel{in=3,out=3,sublayers=[MulConstant{in=1,out=1}, MulConstant{in=1,out=1}, MulConstant{in=1,out=1}]}, Add{in=2,out=1}, Add{in=2,out=1}]\n", + "evaluate_octal.params: [(), (), ()]\n", + "\n" + ] + } + ], + "source": [ + "evaluate_octal = tl.Serial(\n", + " octal_place_values,\n", + " tl.Add(), # Adds the 64's-place values and the 8's-place values.\n", + " tl.Add(), # Adds the 1's-place values to the sums from the previous Add.\n", + ")\n", + "evaluate_octal.initialize_once(input_shapes, input_dtypes, prng_key)\n", + "y = evaluate_octal(inputs)\n", + "\n", + "template = ('inputs:\\n{}\\n\\n'\n", + " 'octal_place_values(inputs):\\n{}\\n\\n'\n", + " 'evaluate_octal(inputs):\\n{}\\n')\n", + "print(template.format(inputs, outputs, y))\n", + "show_layer_properties(evaluate_octal, 'evaluate_octal')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "rwgiP0tK1H6p" + }, + "source": [ + "# Data Flows, Data Stack" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "llAH3cdE1UeU" + }, + "source": [ + "# Training and Using Models" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "65ite-671cTT" + }, + "source": [ + "# Defining Your Own Layer Classes" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "last_runtime": { + "build_target": "//learning/deepmind/dm_python:dm_notebook", + "kind": "private" + }, + "name": "A Conceptual, Practical Introduction to Trax Layers", + "provenance": [ + { + "file_id": "1sF8QbqJ19ZU6oy5z4GUTt4lgUCjqO6kt", + "timestamp": 1569980697572 + }, + { + "file_id": "1EH76AWQ_pvT4i8ZXfkv-SCV4MrmllEl5", + "timestamp": 1563927451951 + } + ] + }, + "kernelspec": { + "display_name": "Python 2", + "name": "python2" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/trax/layers/intro.ipynbE b/trax/layers/intro.ipynbE new file mode 100644 index 000000000..0658786f1 --- /dev/null +++ b/trax/layers/intro.ipynbE @@ -0,0 +1,834 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "7yuytuIllsv1" + }, + "source": [ + "# A Conceptual, Practical Introduction to Trax Layers\n", + "\n", + "This notebook introduces the core concepts and programming components of the Trax library through a series of code samples and explanations. The topics covered in following sections are:\n", + " - **layers**: the basic building blocks and how to combine them into networks\n", + " - **data flows, data stack**: how the Trax runtime moves data through the layers\n", + " - **models**: how to train, evaluate, and run predictions with Trax models\n", + " - **new layer classes**: how to define and test your own Layer classes\n", + "\n", + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "BIl27504La0G" + }, + "source": [ + "## General Setup\n", + "Execute the following few cells (once) before running any of the code samples in this notebook." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "oILRLCWN_16u" + }, + "outputs": [], + "source": [ + "#@title\n", + "# Copyright 2018 Google LLC.\n", + "\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License.\n", + "\n", + "import numpy as onp\n", + "\n", + "\n", + "\n", + "# Import Trax\n", + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "cellView": "both", + "colab": { + "height": 51 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 383, + "status": "ok", + "timestamp": 1570168980195, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "vlGjGoGMTt-D", + "outputId": "6d2ecf3d-3eb8-48a7-ad12-ebefe83afaf1" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/bin/sh: pip: command not found\n", + "/bin/sh: pip: command not found\n" + ] + } + ], + "source": [ + "#@title Run for installation.\n", + "\n", + "! pip install -q -U tensor2tensor\n", + "! pip install -q tensorflow\n", + "\n", + "from tensor2tensor.trax import backend\n", + "from tensor2tensor.trax import layers as tl" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "bYWNWL9MJHv9" + }, + "outputs": [], + "source": [ + "onp.set_printoptions(precision=3) # Less visual noise in the numerical outputs.\n", + "\n", + "def show_layer_properties(layer_obj, layer_name):\n", + " template = ('{}.n_inputs: {}\\n'\n", + " '{}.n_outputs: {}\\n'\n", + " '{}.sublayers: {}\\n'\n", + " '{}.params: {}\\n')\n", + " print(template.format(layer_name, layer_obj.n_inputs,\n", + " layer_name, layer_obj.n_outputs,\n", + " layer_name, layer_obj.sublayers,\n", + " layer_name, layer_obj.params)) " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-LQ89rFFsEdk" + }, + "source": [ + "# Layers\n", + "\n", + "The Layer class represents Trax's concept of a layer, as summarized in the start of the class's docstring:\n", + "```\n", + "class Layer(object):\n", + " \"\"\"Base class for composable layers in a deep learning network.\n", + "\n", + " Layers are the basic building blocks for deep learning models. A Trax layer\n", + " computes a function from zero or more inputs to zero or more outputs,\n", + " optionally using trainable parameters (common) and non-parameter state (not\n", + " common). Authors of new layer subclasses typically override at most two\n", + " methods of the base `Layer` class:\n", + "\n", + " forward(inputs, params=(), state=(), **kwargs):\n", + " Computes this layer's output as part of a forward pass through the model.\n", + "\n", + " new_params_and_state(self, input_shape, input_dtype, rng):\n", + " Returns a (params, state) pair suitable for initializing this layer.\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "LyLVtdxorDPO" + }, + "source": [ + "## A layer computes a function." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ntZ4_eNQldzL" + }, + "source": [ + "A layer computes a function from zero or more inputs to zero or more outputs. The inputs and outputs are NumPy arrays or JAX objects wrapping NumPy arrays.\n", + "\n", + "The simplest layers, those with no parameters, state or sublayers, can be used without initialization. You can think of them (and test them) like simple mathematical functions. For ease of testing and interactive exploration, layer\n", + "objects implement the `__call__ ` method, so you can call them directly on input data:\n", + "```\n", + "y = layer(x)\n", + "```\n", + "\n", + "Layers are also objects, so you can inspect their properties. For example:\n", + "```\n", + "print('Number of inputs required by this layer: {}'.format(layer.n_inputs))\n", + "```\n", + "\n", + "### Example 1. tl.Relu [1 input, 1 output]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "height": 221 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 1543, + "status": "ok", + "timestamp": 1570168982080, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "V09viOSEQvQe", + "outputId": "b7c1c085-3b54-4673-f284-99d6440f8a52" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x:\n", + "[[-7. -6. -5. -4. -3.]\n", + " [-2. -1. 0. 1. 2.]\n", + " [ 3. 4. 5. 6. 7.]]\n", + "\n", + "relu(x):\n", + "[[0. 0. 0. 0. 0.]\n", + " [0. 0. 0. 1. 2.]\n", + " [3. 4. 5. 6. 7.]]\n", + "\n", + "number of inputs expected by this layer: 1\n", + "number of outputs promised by this layer: 1\n" + ] + } + ], + "source": [ + "x = onp.arange(-7, 8).reshape(3, -1).astype(onp.float32)\n", + "\n", + "# Create a layer object (a Relu instance) and apply the layer to data x.\n", + "relu = tl.Relu()\n", + "y = relu(x)\n", + "\n", + "# Show input, output, and two layer properties.\n", + "template = ('x:\\n{}\\n\\n'\n", + " 'relu(x):\\n{}\\n\\n'\n", + " 'number of inputs expected by this layer: {}\\n'\n", + " 'number of outputs promised by this layer: {}')\n", + "print(template.format(x, y, relu.n_inputs, relu.n_outputs))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "7sYxIT8crFVE" + }, + "source": [ + "### Example 2. tl.Concatenate [2 inputs, 1 output]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "height": 442 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 1558, + "status": "ok", + "timestamp": 1570168983657, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "LMPPNWXLoOZI", + "outputId": "24398ccb-9cda-4bdd-c0f0-4904c02a215e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x1:\n", + "[[-7. -6. -5. -4. -3.]\n", + " [-2. -1. 0. 1. 2.]\n", + " [ 3. 4. 5. 6. 7.]]\n", + "\n", + "x2:\n", + "[[-70. -60. -50. -40. -30.]\n", + " [-20. -10. 0. 10. 20.]\n", + " [ 30. 40. 50. 60. 70.]]\n", + "\n", + "concat0([x1, x2]):\n", + "[[ -7. -6. -5. -4. -3.]\n", + " [ -2. -1. 0. 1. 2.]\n", + " [ 3. 4. 5. 6. 7.]\n", + " [-70. -60. -50. -40. -30.]\n", + " [-20. -10. 0. 10. 20.]\n", + " [ 30. 40. 50. 60. 70.]]\n", + "\n", + "concat1([x1, x2]):\n", + "[[ -7. -6. -5. -4. -3. -70. -60. -50. -40. -30.]\n", + " [ -2. -1. 0. 1. 2. -20. -10. 0. 10. 20.]\n", + " [ 3. 4. 5. 6. 7. 30. 40. 50. 60. 70.]]\n", + "\n", + "concat0: Concatenate{in=2,out=1}\n", + "concat1: Concatenate{in=2,out=1}\n" + ] + } + ], + "source": [ + "x1 = onp.arange(-7, 8).reshape(3, -1).astype(onp.float32)\n", + "x2 = 10 * x1\n", + "\n", + "concat0 = tl.Concatenate(axis=0)\n", + "concat1 = tl.Concatenate(axis=1)\n", + "\n", + "y0 = concat0([x1, x2])\n", + "y1 = concat1([x1, x2])\n", + "\n", + "template = ('x1:\\n{}\\n\\n'\n", + " 'x2:\\n{}\\n\\n'\n", + " 'concat0([x1, x2]):\\n{}\\n\\n'\n", + " 'concat1([x1, x2]):\\n{}\\n')\n", + "print(template.format(x1, x2, y0, y1))\n", + "\n", + "# Print abbreviated object representations (useful for debugging).\n", + "print('concat0: {}'.format(concat0))\n", + "print('concat1: {}'.format(concat1))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "1oZv3R8bRMvF" + }, + "source": [ + "## Layers are trainable." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "3d64M7wLryji" + }, + "source": [ + "Most layer types are trainable: they include parameters that modify the computation of outputs from inputs, and they use back-progagated gradients to update those parameters.\n", + "\n", + "Before use, trainable layers must have their parameters initialized, typically using a PRNG (pseudo-random number generator) key for random number generation. Trax's model trainers take care of this behind the scenes, but if you are using a layer in insolation, you have to do the initialization yourself. For this, use the `initialize_once` method:\n", + "\n", + "```\n", + " def initialize_once(self, input_shapes, input_dtype, rng):\n", + " \"\"\"Initializes this layer and its sublayers recursively.\n", + "\n", + " This method is designed to initialize each layer instance once, even if the\n", + " same layer instance occurs in multiple places in the network. This enables\n", + " weight sharing to be implemented as layer sharing.\n", + "\n", + " ...\n", + "```\n", + "\n", + "### Example 3. tl.LayerNorm [1 input, 1 output, has parameters]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "height": 221 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 2555, + "status": "ok", + "timestamp": 1570168986228, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "Ie7iyX91qAx2", + "outputId": "3fe02659-481b-4912-c7eb-85eb01cfadd6" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x:\n", + "[[-7. -6. -5. -4. -3.]\n", + " [-2. -1. 0. 1. 2.]\n", + " [ 3. 4. 5. 6. 7.]]\n", + "\n", + "layer_norm(x):\n", + "[[-1.414 -0.707 0. 0.707 1.414]\n", + " [-1.414 -0.707 0. 0.707 1.414]\n", + " [-1.414 -0.707 0. 0.707 1.414]]\n", + "\n", + "layer_norm.params:\n", + "(_FilledConstant([1., 1., 1., 1., 1.], dtype=float32), _FilledConstant([0., 0., 0., 0., 0.], dtype=float32))\n" + ] + } + ], + "source": [ + "prng_key = backend.random.get_prng(0) # Used below for layer initialization.\n", + "x = onp.arange(-7, 8).reshape(3, -1).astype(onp.float32)\n", + "\n", + "layer_norm = tl.LayerNorm()\n", + "layer_norm.initialize_once(x.shape, x.dtype, prng_key)\n", + "y = layer_norm(x)\n", + "\n", + "template = ('x:\\n{}\\n\\n'\n", + " 'layer_norm(x):\\n{}\\n')\n", + "print(template.format(x, y))\n", + "print('layer_norm.params:\\n{}'.format(layer_norm.params))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ZWZUXEJAofH-" + }, + "source": [ + "## Layers combine into layers." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "d47gVdGV1vWw" + }, + "source": [ + "The Trax library authors encourage users, where possible, to build new layers as combinations of existing layers. The library provides a small set of _combinator_ layers for this: layer objects that make a list of layers behave as a single layer (a unit able to compute outputs from inputs, update parameters from gradients, and combine with yet more layers).\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "vC1ymG2j0iyp" + }, + "source": [ + "## Combine with Serial(...)\n", + "\n", + "The most common way to combine layers is serially, using the `Serial` class:\n", + "```\n", + "class Serial(base.Layer):\n", + " \"\"\"Combinator that applies layers serially (by function composition).\n", + "\n", + " A Serial combinator uses stack semantics to manage data for its sublayers.\n", + " Each sublayer sees only the inputs it needs and returns only the outputs it\n", + " has generated. The sublayers interact via the data stack. For instance, a\n", + " sublayer k, following sublayer j, gets called with the data stack in the\n", + " state left after layer j has applied. The Serial combinator then:\n", + "\n", + " - takes N_in items off the top of the stack (N_in = k.n_inputs) and calls\n", + " layer k, passing those items as arguments; and\n", + "\n", + " - takes layer k's N_out return values (N_out = k.n_outputs) and pushes\n", + " them onto the data stack.\n", + "\n", + " ...\n", + "```\n", + "As described above, the output of one layer is the input of the next, which amounts to function composition:\n", + "\n", + "```\n", + "# h(.) = g(f(.))\n", + "layer_h = Serial(\n", + " layer_f,\n", + " layer_g,\n", + ")\n", + "```\n", + "\n", + "### Example 4. y = layer_norm(relu(x)) [1 input, 1 output, has parameters]" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "height": 170 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 1664, + "status": "ok", + "timestamp": 1570168987915, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "dW5fpusjvjmh", + "outputId": "207f6a59-b767-414f-a836-ec342157ef51" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x:\n", + "[[-7. -6. -5. -4. -3.]\n", + " [-2. -1. 0. 1. 2.]\n", + " [ 3. 4. 5. 6. 7.]]\n", + "\n", + "layer_block(x):\n", + "[[ 0. 0. 0. 0. 0. ]\n", + " [-0.75 -0.75 -0.75 0.5 1.75 ]\n", + " [-1.414 -0.707 0. 0.707 1.414]]\n" + ] + } + ], + "source": [ + "prng_key = backend.random.get_prng(0)\n", + "x = onp.arange(-7, 8).reshape(3, -1).astype(onp.float32)\n", + "\n", + "layer_block = tl.Serial(\n", + " tl.Relu(),\n", + " tl.LayerNorm(),\n", + ")\n", + "layer_block.initialize_once(x.shape, x.dtype, prng_key)\n", + "y = layer_block(x)\n", + "\n", + "template = ('x:\\n{}\\n\\n'\n", + " 'layer_block(x):\\n{}')\n", + "print(template.format(x, y,))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "bRtmN6ckQO1q" + }, + "source": [ + "And we can inspect the block as a whole, as if it were just another layer:\n", + "\n", + "### Example 5. Inspecting a Serial layer." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "height": 102 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 289, + "status": "ok", + "timestamp": 1570168988225, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "D6BpYddZQ1eu", + "outputId": "03a99733-cd84-4639-fb1f-8dfacebf5b07" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "layer_block:\n", + "Serial{in=1,out=1,sublayers=[Relu{in=1,out=1}, LayerNorm{in=1,out=1}]}\n", + "\n", + "layer_block.params:\n", + "[(), (_FilledConstant([1., 1., 1., 1., 1.], dtype=float32), _FilledConstant([0., 0., 0., 0., 0.], dtype=float32))]\n" + ] + } + ], + "source": [ + "print('layer_block:\\n{}\\n'.format(layer_block))\n", + "\n", + "print('layer_block.params:\\n{}'.format(layer_block.params))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PqVNdoONcTp0" + }, + "source": [ + "## Combine with Parallel(...)\n", + "\n", + "The `Parallel` combinator arranges layers into separate computational channels, each with its own inputs/outputs and gradient flows:\n", + "```\n", + "class Parallel(base.Layer):\n", + " \"\"\"Combinator that applies a list of layers in parallel to its inputs.\n", + "\n", + " Layers in the list apply to successive spans of inputs, where the spans are\n", + " determined how many inputs each layer takes. The resulting output is the\n", + " (flattened) concatenation of the resepective layer outputs.\n", + "\n", + " For example, suppose one has three layers:\n", + "\n", + " - F: 1 input, 1 output\n", + " - G: 3 inputs, 1 output\n", + " - H: 2 inputs, 2 outputs (h1, h2)\n", + "\n", + " Then Parallel(F, G, H) will take 6 inputs and give 4 outputs:\n", + "\n", + " - inputs: a, b, c, d, e, f\n", + " - outputs: F(a), G(b, c, d), h1, h2\n", + "```\n", + "\n", + "Separate (parallel) computation channels make sense when each channel can do its work (computing outputs from inputs) independent of the inputs and outputs of the others.\n", + "\n", + "As a simplistic example, consider writing a converter from three-digit octal (base 8) numerals to their corresponding values. For instance, to do conversions such as\n", + "```\n", + "123 (octal) = 1 * 8^2 + 2 * 8^1 + 3 * 8^0 = 83 (decimal)\n", + "345 (octal) = 3 * 8^2 + 4 * 8^1 + 6 * 8^0 = 229 (decimal)\n", + "567 (octal) = 5 * 8^2 + 6 * 8^1 + 7 * 8^0 = 375 (decimal)\n", + "```\n", + "the digits can first be converted independently, according to their place value (multiply by 64, multiply by 8, or multiply by 1). The following code runs the 64's-place digits ([1, 3, 5]) through one layer, the 8's-place digits ([2, 4, 6]) through a different layer, and the 1's-place digits ([3, 5, 7]) through yet a different layer. These three layers are combined in a Parallel layer:\n", + "\n", + "### Example 6. Processing octal digits in parallel." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "height": 204 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 2224, + "status": "ok", + "timestamp": 1570168990465, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "uQMqq3h_b2jQ", + "outputId": "f3a43cae-e271-493a-f74a-31e1ff971bc1" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "inputs:\n", + "(array([1, 3, 5]), array([2, 4, 6]), array([3, 5, 7]))\n", + "\n", + "octal_place_values(inputs):\n", + "(array([ 64., 192., 320.]), array([16., 32., 48.]), array([3., 5., 7.]))\n", + "\n", + "octal_place_values.n_inputs: 3\n", + "octal_place_values.n_outputs: 3\n", + "octal_place_values.sublayers: [MulConstant{in=1,out=1}, MulConstant{in=1,out=1}, MulConstant{in=1,out=1}]\n", + "octal_place_values.params: ((), (), ())\n", + "\n" + ] + } + ], + "source": [ + "prng_key = backend.random.get_prng(0)\n", + "place_64_digits = onp.array([1, 3, 5])\n", + "place_8_digits = onp.array([2, 4, 6])\n", + "place_1_digits = onp.array([3, 5, 7])\n", + "inputs = (place_64_digits, place_8_digits, place_1_digits)\n", + "input_shapes = [[3]] * 3\n", + "input_dtypes = [onp.int32] * 3\n", + "\n", + "# Create three simple layers, each for converting a different digit in base 8.\n", + "sixty_fours = tl.MulConstant(constant=64.0) # 8^2: 100 in base 8\n", + "eights = tl.MulConstant(constant=8.0) # 8^1: 10 in base 8\n", + "ones = tl.MulConstant(constant=1.0) # 8^0: 1 in base 8\n", + "\n", + "# Create a combined layer to convert digits to values (using big-endian base 8),\n", + "# initialize it, and apply it.\n", + "octal_place_values = tl.Parallel(sixty_fours, eights, ones)\n", + "octal_place_values.initialize_once(input_shapes, input_dtypes, prng_key)\n", + "outputs = octal_place_values(inputs)\n", + "\n", + "# Show inputs, outputs, and properties.\n", + "template = ('inputs:\\n{}\\n\\n'\n", + " 'octal_place_values(inputs):\\n{}\\n')\n", + "print(template.format(inputs, outputs))\n", + "show_layer_properties(octal_place_values, 'octal_place_values')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "q_xcWide3e5f" + }, + "source": [ + "To complete the example, the three outputs (values for the different digits) are combined by successive pairwise additions:\n", + "\n", + "### Example 6'. Combining outputs from upstream parallel layers." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "colab": { + "height": 275 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 2139, + "status": "ok", + "timestamp": 1570168992621, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "ZDCkrvUp3u0-", + "outputId": "696f21aa-5dad-4284-bfdd-ae637e2ce53f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "inputs:\n", + "(array([1, 3, 5]), array([2, 4, 6]), array([3, 5, 7]))\n", + "\n", + "octal_place_values(inputs):\n", + "(array([ 64., 192., 320.]), array([16., 32., 48.]), array([3., 5., 7.]))\n", + "\n", + "evaluate_octal(inputs):\n", + "[ 83. 229. 375.]\n", + "\n", + "evaluate_octal.n_inputs: 3\n", + "evaluate_octal.n_outputs: 1\n", + "evaluate_octal.sublayers: [Parallel{in=3,out=3,sublayers=[MulConstant{in=1,out=1}, MulConstant{in=1,out=1}, MulConstant{in=1,out=1}]}, Add{in=2,out=1}, Add{in=2,out=1}]\n", + "evaluate_octal.params: [(), (), ()]\n", + "\n" + ] + } + ], + "source": [ + "evaluate_octal = tl.Serial(\n", + " octal_place_values,\n", + " tl.Add(), # Adds the 64's-place values and the 8's-place values.\n", + " tl.Add(), # Adds the 1's-place values to the sums from the previous Add.\n", + ")\n", + "evaluate_octal.initialize_once(input_shapes, input_dtypes, prng_key)\n", + "y = evaluate_octal(inputs)\n", + "\n", + "template = ('inputs:\\n{}\\n\\n'\n", + " 'octal_place_values(inputs):\\n{}\\n\\n'\n", + " 'evaluate_octal(inputs):\\n{}\\n')\n", + "print(template.format(inputs, outputs, y))\n", + "show_layer_properties(evaluate_octal, 'evaluate_octal')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "rwgiP0tK1H6p" + }, + "source": [ + "# Data Flows, Data Stack" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "llAH3cdE1UeU" + }, + "source": [ + "# Training and Using Models" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "65ite-671cTT" + }, + "source": [ + "# Defining Your Own Layer Classes" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "last_runtime": { + "build_target": "//learning/deepmind/dm_python:dm_notebook", + "kind": "private" + }, + "name": "A Conceptual, Practical Introduction to Trax Layers", + "provenance": [ + { + "file_id": "1sF8QbqJ19ZU6oy5z4GUTt4lgUCjqO6kt", + "timestamp": 1569980697572 + }, + { + "file_id": "1EH76AWQ_pvT4i8ZXfkv-SCV4MrmllEl5", + "timestamp": 1563927451951 + } + ] + }, + "kernelspec": { + "display_name": "Python 2", + "name": "python2" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/trax/layers/metrics.py b/trax/layers/metrics.py new file mode 100644 index 000000000..5fff22fd9 --- /dev/null +++ b/trax/layers/metrics.py @@ -0,0 +1,124 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Trax metrics layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from trax.backend import numpy as np +from trax.layers import base +from trax.layers import combinators as cb +from trax.layers import core + + +@base.layer(n_inputs=2, n_outputs=1) +def CrossEntropy(x, axis=-1, **kw): + del kw + prediction, target = x + return np.sum(prediction * core.one_hot(target, prediction.shape[-1]), + axis=axis) + + +@base.layer(n_inputs=2, n_outputs=1) +def L2(x, axis=-1, **kw): + del kw + prediction, target = x + return np.sum((prediction - target)**2, axis=axis) + + +@base.layer(n_inputs=2, n_outputs=1) +def Accuracy(x, axis=-1, **kw): + del kw + prediction, target = x + predicted_class = np.argmax(prediction, axis=axis) + return np.equal(predicted_class, target) + + +@base.layer() +def WeightMask(target, mask_id=0, **kw): + del kw + if mask_id is None: + return np.ones_like(target) + return 1.0 - np.equal(target, mask_id).astype(np.float32) + + +@base.layer(n_inputs=2, n_outputs=1) +def WeightedMean(x, **kw): + del kw + metric, weights = x + weights_sum = np.sum(weights) + return np.sum(metric * weights) / weights_sum + + +def MaskedScalar(metric_layer, mask_id=None, has_weights=False): + """Metric as scalar compatible with Trax masking.""" + # Stack of (inputs, targets) --> (metric, weight-mask). + metric_and_mask = [ + cb.Parallel( + [], + cb.Dup() # Duplicate targets + ), + cb.Parallel( + metric_layer, # Metric: (inputs, targets) --> metric + WeightMask(mask_id=mask_id) # pylint: disable=no-value-for-parameter + ) + ] + if not has_weights: + # Take (metric, weight-mask) and return the weighted mean. + return cb.Serial([metric_and_mask, WeightedMean()]) # pylint: disable=no-value-for-parameter + return cb.Serial([ + metric_and_mask, + cb.Parallel( + [], + cb.Multiply() # Multiply given weights by mask_id weights + ), + WeightedMean() # pylint: disable=no-value-for-parameter + ]) + + +def CrossEntropyScalar(mask_id=None, has_weights=False): + """Cross-entropy as scalar compatible with Trax masking.""" + return MaskedScalar(CrossEntropy(), mask_id=mask_id, has_weights=has_weights) # pylint: disable=no-value-for-parameter + + +NegLogPerplexityScalar = CrossEntropyScalar + + +def CrossEntropyLossScalar(mask_id=None, has_weights=False): + """Cross-entropy loss as scalar compatible with Trax masking.""" + return cb.Serial( + CrossEntropyScalar(mask_id=mask_id, has_weights=has_weights), + core.MulConstant(constant=-1.0) + ) + + +def L2Scalar(mask_id=None, has_weights=False): + """L2 as scalar compatible with Trax masking.""" + return MaskedScalar(L2(), mask_id=mask_id, has_weights=has_weights) # pylint: disable=no-value-for-parameter + + +def L2LossScalar(mask_id=None, has_weights=False): + """L2 loss as scalar compatible with Trax masking.""" + return cb.Serial( + L2Scalar(mask_id=mask_id, has_weights=has_weights), + core.MulConstant(constant=-1.0) + ) + + +def AccuracyScalar(mask_id=None, has_weights=False): + """Accuracy as scalar compatible with Trax masking.""" + return MaskedScalar(Accuracy(), mask_id=mask_id, has_weights=has_weights) # pylint: disable=no-value-for-parameter diff --git a/trax/layers/metrics_test.py b/trax/layers/metrics_test.py new file mode 100644 index 000000000..6708a3e7e --- /dev/null +++ b/trax/layers/metrics_test.py @@ -0,0 +1,90 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for metrics layers.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +import numpy as onp +from trax import backend +from trax.layers import base +from trax.layers import metrics + + +class MetricsLayerTest(absltest.TestCase): + + def test_cross_entropy(self): + input_shape = ((29, 4, 4, 20), (29, 4, 4)) + result_shape = base.check_shape_agreement( + metrics.CrossEntropy(), input_shape) + self.assertEqual(result_shape, (29, 4, 4)) + + def test_accuracy(self): + input_shape = ((29, 4, 4, 20), (29, 4, 4)) + result_shape = base.check_shape_agreement( + metrics.Accuracy(), input_shape) + self.assertEqual(result_shape, (29, 4, 4)) + + def test_weight_mask(self): + input_shape = (29, 4, 4, 20) + result_shape = base.check_shape_agreement( + metrics.WeightMask(), input_shape) + self.assertEqual(result_shape, input_shape) + + def test_weighted_mean_shape(self): + input_shape = ((29, 4, 4, 20), (29, 4, 4, 20)) + result_shape = base.check_shape_agreement( + metrics.WeightedMean(), input_shape) + self.assertEqual(result_shape, ()) + + def test_weighted_mean_semantics(self): + inputs = onp.array([1, 2, 3], dtype=onp.float32) + weights1 = onp.array([1, 1, 1], dtype=onp.float32) + layer = metrics.WeightedMean() + rng = backend.random.get_prng(0) + layer.initialize_once((inputs.shape, weights1.shape), + (inputs.dtype, weights1.dtype), rng) + mean1 = layer((inputs, weights1)) + onp.testing.assert_allclose(mean1, 2.0) + weights2 = onp.array([0, 0, 1], dtype=onp.float32) + mean2 = layer((inputs, weights2)) + onp.testing.assert_allclose(mean2, 3.0) + weights3 = onp.array([1, 0, 0], dtype=onp.float32) + mean3 = layer((inputs, weights3)) + onp.testing.assert_allclose(mean3, 1.0) + + def test_cross_entropy_scalar(self): + input_shape = ((29, 4, 4, 20), (29, 4, 4)) + result_shape = base.check_shape_agreement( + metrics.CrossEntropyScalar(), input_shape) + self.assertEqual(result_shape, ()) + + def test_cross_entropy_loss_scalar(self): + input_shape = ((29, 4, 4, 20), (29, 4, 4)) + result_shape = base.check_shape_agreement( + metrics.CrossEntropyLossScalar(), input_shape) + self.assertEqual(result_shape, ()) + + def test_accuracy_scalar(self): + input_shape = ((29, 4, 4, 20), (29, 4, 4)) + result_shape = base.check_shape_agreement( + metrics.AccuracyScalar(), input_shape) + self.assertEqual(result_shape, ()) + + +if __name__ == "__main__": + absltest.main() diff --git a/trax/layers/normalization.py b/trax/layers/normalization.py new file mode 100644 index 000000000..fffc0f715 --- /dev/null +++ b/trax/layers/normalization.py @@ -0,0 +1,137 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Trax normalization layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from trax.backend import numpy as np +from trax.layers import base + + +class BatchNorm(base.Layer): + """Batch normalization.""" + + def __init__(self, axis=(0, 1, 2), epsilon=1e-5, center=True, scale=True, + momentum=0.999, mode='train'): + super(BatchNorm, self).__init__() + self._axis = axis + self._epsilon = epsilon + self._center = center + self._scale = scale + self._momentum = momentum + self._mode = mode + + def new_params_and_state(self, input_shape, input_dtype, rng): + """Helper to initialize batch norm params.""" + del input_dtype, rng + axis = self._axis + axis = (axis,) if np.isscalar(axis) else axis + shape = tuple(d for i, d in enumerate(input_shape) if i not in axis) + beta = np.zeros(shape, dtype='float32') if self._center else () + gamma = np.ones(shape, dtype='float32') if self._scale else () + def get_stats_axis(i, d): + if i in axis: + return 1 + else: + return d + stats_shape = tuple(get_stats_axis(i, d) for i, d in enumerate(input_shape)) + running_mean = np.zeros(stats_shape, dtype=np.float32) + running_var = np.ones(stats_shape, dtype=np.float32) + n_batches = np.zeros((), dtype=np.int32) + params = (beta, gamma) + state = (running_mean, running_var, n_batches) + return params, state + + def _fast_mean_and_variance(self, x): + mean = np.mean(x, self._axis, keepdims=True) + # Fast but less numerically-stable variance calculation than np.var. + m1 = np.mean(x**2, self._axis, keepdims=True) + variance = m1 - mean**2 + return mean, variance + + def _exponential_smoothing(self, new, old): + smoothed_value = self._momentum * old + (1 - self._momentum) * new + return smoothed_value.astype(old.dtype) + + def _z_score(self, x, mean, variance): + mu = mean.astype(x.dtype) + sigma = np.sqrt(variance + self._epsilon).astype(x.dtype) + return (x - mu) / sigma + + def _beta_gamma_with_correct_axes(self, x, params): + # Expand the parameters to have the right axes. + beta, gamma = params + # TODO(phawkins): np.expand_dims should accept an axis tuple. + # (https://github.com/numpy/numpy/issues/12290) + ed = tuple(None if i in self._axis else slice(None) + for i in range(np.ndim(x))) + beta = beta[ed] + gamma = gamma[ed] + return beta, gamma + + def forward(self, x, params, state, **unused_kwargs): + """Computes batch normalization as part of a forward pass in the model.""" + + running_mean, running_var, n_batches = state + if self._mode == 'train': + n_batches += 1 + mean, var = self._fast_mean_and_variance(x) + running_mean = self._exponential_smoothing(mean, running_mean) + running_var = self._exponential_smoothing(var, running_var) + state = (running_mean, running_var, n_batches) + else: + mean = running_mean + var = running_var + + z = self._z_score(x, mean, var) + beta, gamma = self._beta_gamma_with_correct_axes(x, params) + + # Return the z rescaled by the parameters if requested. + if self._center and self._scale: + output = gamma * z + beta + elif self._center: + output = z + beta + elif self._scale: + output = gamma * z + else: + output = z + assert output.dtype == x.dtype, ('The dtype of the output (%s) of batch ' + 'norm is not the same as the input (%s). ' + 'Batch norm should not change the dtype' % + (output.dtype, x.dtype)) + return output, state + + +# Layer normalization. +def _layer_norm_params_and_state(input_shape, input_dtype, rng): + """Helper: create layer norm parameters.""" + del input_dtype, rng + features = input_shape[-1] + scale = np.ones(features) + bias = np.zeros(features) + params = (scale, bias) + return params, () + + +@base.layer(new_params_and_state_fn=_layer_norm_params_and_state) +def LayerNorm(x, params, epsilon=1e-6, **unused_kwargs): # pylint: disable=invalid-name + (scale, bias) = params + mean = np.mean(x, axis=-1, keepdims=True) + variance = np.mean((x - mean)**2, axis=-1, keepdims=True) + norm_inputs = (x - mean) / np.sqrt(variance + epsilon) + return norm_inputs * scale + bias diff --git a/trax/layers/normalization_test.py b/trax/layers/normalization_test.py new file mode 100644 index 000000000..17cdca955 --- /dev/null +++ b/trax/layers/normalization_test.py @@ -0,0 +1,69 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for normalization layers.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +import numpy as onp + +from trax import backend +from trax.backend import numpy as np +from trax.layers import base +from trax.layers import normalization + + +class NormalizationLayerTest(absltest.TestCase): + + def test_batch_norm_shape(self): + input_shape = (29, 5, 7, 20) + result_shape = base.check_shape_agreement( + normalization.BatchNorm(), input_shape) + self.assertEqual(result_shape, input_shape) + + def test_batch_norm(self): + input_shape = (2, 3, 4) + input_dtype = np.float32 + eps = 1e-5 + rng = backend.random.get_prng(0) + inp1 = np.reshape(np.arange(np.prod(input_shape), dtype=input_dtype), + input_shape) + m1 = 11.5 # Mean of this random input. + v1 = 47.9167 # Variance of this random input. + layer = normalization.BatchNorm(axis=(0, 1, 2)) + _, _ = layer.initialize_once(input_shape, input_dtype, rng) + state = layer.state + onp.testing.assert_allclose(state[0], 0) + onp.testing.assert_allclose(state[1], 1) + self.assertEqual(state[2], 0) + out = layer(inp1) + state = layer.state + onp.testing.assert_allclose(state[0], m1 * 0.001) + onp.testing.assert_allclose(state[1], 0.999 + v1 * 0.001, rtol=1e-6) + self.assertEqual(state[2], 1) + onp.testing.assert_allclose(out, (inp1 - m1) / np.sqrt(v1 + eps), + rtol=1e-6) + + def test_layer_norm_shape(self): + input_shape = (29, 5, 7, 20) + result_shape = base.check_shape_agreement( + normalization.LayerNorm(), input_shape) + self.assertEqual(result_shape, input_shape) + + +if __name__ == "__main__": + absltest.main() diff --git a/trax/layers/pooling.py b/trax/layers/pooling.py new file mode 100644 index 000000000..47dc605da --- /dev/null +++ b/trax/layers/pooling.py @@ -0,0 +1,44 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Trax pooling layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from trax import backend +from trax.layers import base + + +@base.layer() +def MaxPool(x, params, pool_size=(2, 2), strides=None, padding='VALID', **kw): + del params, kw + return backend.max_pool(x, pool_size=pool_size, strides=strides, + padding=padding) + + +@base.layer() +def SumPool(x, params, pool_size=(2, 2), strides=None, padding='VALID', **kw): + del params, kw + return backend.sum_pool(x, pool_size=pool_size, strides=strides, + padding=padding) + + +@base.layer() +def AvgPool(x, params, pool_size=(2, 2), strides=None, padding='VALID', **kw): + del params, kw + return backend.avg_pool(x, pool_size=pool_size, strides=strides, + padding=padding) diff --git a/trax/layers/pooling_test.py b/trax/layers/pooling_test.py new file mode 100644 index 000000000..1ec8bfc33 --- /dev/null +++ b/trax/layers/pooling_test.py @@ -0,0 +1,36 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for conv layers.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from trax.layers import base +from trax.layers import pooling + + +class PoolingLayerTest(absltest.TestCase): + + def test_avg_pool(self): + input_shape = (29, 4, 4, 20) + result_shape = base.check_shape_agreement( + pooling.AvgPool(pool_size=(2, 2), strides=(2, 2)), input_shape) + self.assertEqual(result_shape, (29, 2, 2, 20)) + + +if __name__ == "__main__": + absltest.main() diff --git a/trax/layers/reversible.py b/trax/layers/reversible.py new file mode 100644 index 000000000..c1e1456ac --- /dev/null +++ b/trax/layers/reversible.py @@ -0,0 +1,126 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementations of reversible layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import jax +from trax import backend +from trax.layers import base +from trax.layers import combinators as cb + + +class ReversibleLayer(base.Layer): + """Reversible Layer.""" + + def reverse(self, output, params=(), state=(), **kwargs): + """Reverse this layer: compute input given output.""" + raise NotImplementedError + + def reverse_and_grad(self, output, grad, params=(), state=(), **kwargs): + """Backward pass: computes the inverse of a layer and propagates gradients. + + While you may choose to only implement reverse, some layers implement this + function directly as computation may be shared between reversing and + computing gradients. + + Args: + output: Output activations; can be a (possibly nested) tuple. + grad: gradient signal (cotangent) computed based on subsequent layers. + The structure and shape must match the output. + params: layer parameters + state: start state + **kwargs: kwargs for the layer + + Returns: + A tuple (x, (x_grad, params_grad)), where x is the reconstructed input, + x_grad is the gradient signal for the input, and params_grad is the + gradient signal for the parameters. + """ + # Note: jax.vjp does not allow us to use **kwargs in the signature here. + def _do_forward(x, params): + return super(ReversibleLayer, self).forward( + x, params=params, state=state, **kwargs)[0] + + reconstructed_x = self.reverse(output, params, state, **kwargs) + _, vjpfun = jax.vjp(_do_forward, reconstructed_x, params) + x_params_grad = vjpfun(grad) + return reconstructed_x, x_params_grad + + @property + def has_backward(self): + return True + + def backward(self, inputs, output, ct, params, state, **kwargs): + del inputs + _, inputs_params_ct = self.reverse_and_grad(output, ct, params, state, + **kwargs) + return inputs_params_ct + + +class ReversibleSwap(ReversibleLayer, cb.Swap): + """Swap the first two element on the stack.""" + + def reverse(self, output, params=(), state=(), **kwargs): + # Swap is its own inverse, except that reverse doesn't return the state. + return self.forward(output, params=params, state=state, **kwargs)[0] + + +class ReversibleSerial(ReversibleLayer, cb.Serial): + """A reversible version of tl.Serial (requires reversible sub-layers).""" + + def __init__(self, *layers): + super(ReversibleSerial, self).__init__(*layers) + + # Note that sublayers has already been flattened to remove nested lists. + for i, layer in enumerate(self.sublayers): + if not isinstance(layer, ReversibleLayer): + raise ValueError( + 'Sub-layer {} of ReversibleSerial is not reversible: {}'.format( + i, layer)) + + def reverse(self, output, params=(), state=(), **kwargs): + rng = kwargs.pop('rng', None) + rngs = (None,) * self._n_layers + if rng is not None: + rngs = backend.random.split(rng, self._n_layers) + + layer_val = output + for layer, p, s, rng in reversed(list(zip(self.sublayers, + params, state, rngs))): + layer_val = layer.reverse(layer_val, p, s, rng=rng, **kwargs) + + return layer_val + + def reverse_and_grad(self, output, ct, params=(), state=(), **kwargs): + rng = kwargs.pop('rng', None) + rngs = (None,) * self._n_layers + if rng is not None: + rngs = backend.random.split(rng, self._n_layers) + + layer_val = output + layer_ct = ct + params_ct = [] + for layer, p, s, rng in reversed(list(zip(self.sublayers, + params, state, rngs))): + layer_val, layer_ct = layer.reverse_and_grad( + layer_val, layer_ct, p, s, rng=rng, **kwargs) + layer_ct, p_ct = layer_ct + params_ct.insert(0, p_ct) + + return layer_val, (layer_ct, params_ct) diff --git a/trax/layers/reversible_test.py b/trax/layers/reversible_test.py new file mode 100644 index 000000000..f4dfd5229 --- /dev/null +++ b/trax/layers/reversible_test.py @@ -0,0 +1,37 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for reversible layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from trax.layers import base +from trax.layers import reversible + + +class ReversibleLayerTest(absltest.TestCase): + + def test_reversible_swap(self): + layer = reversible.ReversibleSwap() + input_shape = ((2, 3), (3, 3)) + final_shape = base.check_shape_agreement(layer, input_shape) + self.assertEqual(final_shape, input_shape[::-1]) + + +if __name__ == '__main__': + absltest.main() diff --git a/trax/layers/rnn.py b/trax/layers/rnn.py new file mode 100644 index 000000000..73c470fba --- /dev/null +++ b/trax/layers/rnn.py @@ -0,0 +1,128 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementations of common recurrent neural network cells (RNNs).""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from trax.layers import combinators as cb +from trax.layers import convolution +from trax.layers import core + + +def GRUCell(n_units): + """Builds a traditional GRU cell with dense internal transformations. + + Gated Recurrent Unit paper: https://arxiv.org/abs/1412.3555 + + + Args: + n_units: Number of hidden units. + + Returns: + A Stax model representing a traditional GRU RNN cell. + """ + return GeneralGRUCell( + candidate_transform=lambda: core.Dense(n_units), + memory_transform_fn=None, + gate_nonlinearity=core.Sigmoid, + candidate_nonlinearity=core.Tanh) + + +def ConvGRUCell(n_units, kernel_size=(3, 3)): + """Builds a convolutional GRU. + + Paper: https://arxiv.org/abs/1511.06432. + + Args: + n_units: Number of hidden units + kernel_size: Kernel size for convolution + + Returns: + A Stax model representing a GRU cell with convolution transforms. + """ + + def BuildConv(): + return convolution.Conv( + filters=n_units, kernel_size=kernel_size, padding='SAME') + + return GeneralGRUCell( + candidate_transform=BuildConv, + memory_transform_fn=None, + gate_nonlinearity=core.Sigmoid, + candidate_nonlinearity=core.Tanh) + + +def GeneralGRUCell(candidate_transform, + memory_transform_fn=None, + gate_nonlinearity=core.Sigmoid, + candidate_nonlinearity=core.Tanh, + dropout_rate_c=0.1, + sigmoid_bias=0.5): + r"""Parametrized Gated Recurrent Unit (GRU) cell construction. + + GRU update equations: + $$ Update gate: u_t = \sigmoid(U' * s_{t-1} + B') $$ + $$ Reset gate: r_t = \sigmoid(U'' * s_{t-1} + B'') $$ + $$ Candidate memory: c_t = \tanh(U * (r_t \odot s_{t-1}) + B) $$ + $$ New State: s_t = u_t \odot s_{t-1} + (1 - u_t) \odot c_t $$ + + See combinators.Gate for details on the gating function. + + + Args: + candidate_transform: Transform to apply inside the Candidate branch. Applied + before nonlinearities. + memory_transform_fn: Optional transformation on the memory before gating. + gate_nonlinearity: Function to use as gate activation. Allows trying + alternatives to Sigmoid, such as HardSigmoid. + candidate_nonlinearity: Nonlinearity to apply after candidate branch. Allows + trying alternatives to traditional Tanh, such as HardTanh + dropout_rate_c: Amount of dropout on the transform (c) gate. Dropout works + best in a GRU when applied exclusively to this branch. + sigmoid_bias: Constant to add before sigmoid gates. Generally want to start + off with a positive bias. + + Returns: + A model representing a GRU cell with specified transforms. + """ + gate_block = [ # u_t + candidate_transform(), + core.AddConstant(constant=sigmoid_bias), + gate_nonlinearity(), + ] + reset_block = [ # r_t + candidate_transform(), + core.AddConstant(constant=sigmoid_bias), # Want bias to start positive. + gate_nonlinearity(), + ] + candidate_block = [ + cb.Dup(), + reset_block, + cb.Multiply(), # Gate S{t-1} with sigmoid(candidate_transform(S{t-1})) + candidate_transform(), # Final projection + tanh to get Ct + candidate_nonlinearity(), # Candidate gate + + # Only apply dropout on the C gate. Paper reports 0.1 as a good default. + core.Dropout(rate=dropout_rate_c) + ] + memory_transform = memory_transform_fn() if memory_transform_fn else [] + return cb.Model( + cb.Dup(), cb.Dup(), + cb.Parallel(memory_transform, gate_block, candidate_block), + cb.Gate(), + ) diff --git a/trax/layers/rnn_test.py b/trax/layers/rnn_test.py new file mode 100644 index 000000000..61e775928 --- /dev/null +++ b/trax/layers/rnn_test.py @@ -0,0 +1,45 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for rnn layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from trax.layers import base +from trax.layers import rnn + + +class RnnLayerTest(absltest.TestCase): + + def _test_cell_runs(self, layer, input_shape, output_shape): + final_shape = base.check_shape_agreement(layer, input_shape) + self.assertEqual(output_shape, final_shape) + + def test_conv_gru_cell(self): + self._test_cell_runs( + rnn.ConvGRUCell(9, kernel_size=(3, 3)), + input_shape=(8, 1, 7, 9), + output_shape=(8, 1, 7, 9)) + + def test_gru_cell(self): + self._test_cell_runs( + rnn.GRUCell(9), input_shape=(8, 7, 9), output_shape=(8, 7, 9)) + + +if __name__ == '__main__': + absltest.main() diff --git a/trax/learning_rate.py b/trax/learning_rate.py new file mode 100644 index 000000000..426576927 --- /dev/null +++ b/trax/learning_rate.py @@ -0,0 +1,264 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""trax learning rate schedules. + +The learning rate schedules here all have the signature: + lr: history -> (step -> {"learning_rate": lr}) + +That is, they are functions that take a trax.history.History and return a +function that takes a step and returns a dict with entry "learning_rate". +""" + +# TODO(pkozakowski): Revisit the decision to control nontrainable parameters +# using LR schedules, or at least rename the module. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +import time + +from absl import logging +import gin + +from trax import models as trax_models +from trax import utils +from trax.backend import numpy as np +from trax.backend import random as jax_random +from trax.rl import online_tune +from trax.rl import ppo + + +@gin.configurable(blacklist=["history"]) +def MultifactorSchedule(history=None, + factors="constant * linear_warmup * rsqrt_decay", + constant=0.1, + warmup_steps=400, + decay_factor=0.5, + steps_per_decay=20000): + """Factor-based learning rate schedule. + + Interprets factors in the factors string which can consist of: + * constant: interpreted as the constant value, + * linear_warmup: interpreted as linear warmup until warmup_steps, + * rsqrt_decay: divide by square root of max(step, warmup_steps) + * decay_every: Every k steps decay the learning rate by decay_factor. + + Args: + history: the history of training and evaluation (History object). + factors: a string with factors separated by "*" that defines the schedule. + constant: float, the starting constant for the learning rate schedule. + warmup_steps: how many steps to warm up for in the warmup schedule. + decay_factor: The amount to decay the learning rate by. + steps_per_decay: How often to decay the learning rate. + + Returns: + a function learning_rate(step): float -> {"learning_rate": float}, the + step-dependent lr. + """ + del history + + factors = [n.strip() for n in factors.split("*")] + + def learning_rate(step): # pylint: disable=invalid-name + """Step to learning rate function.""" + ret = 1.0 + for name in factors: + if name == "constant": + ret *= constant + elif name == "linear_warmup": + ret *= np.minimum(1.0, step / warmup_steps) + elif name == "rsqrt_decay": + ret /= np.sqrt(np.maximum(step, warmup_steps)) + elif name == "decay_every": + ret *= (decay_factor ** (step//steps_per_decay)) + else: + raise ValueError("Unknown factor %s." % name) + ret = np.asarray(ret, dtype=np.float32) + return {"learning_rate": ret} + + return learning_rate + + +@gin.configurable(blacklist=["history"]) +def EvalAdjustingSchedule(history, + constant=0.1, + steps_to_decrease=20, + improvement_margin=0.001, + decrease_rate=1.5, + history_mode="eval", + metric="metrics/accuracy"): + """Learning rate that decreases when eval metric stalls. + + If the chosen metric does not improve by improvement_margin for as many as + steps_to_decrease steps, then the constant gets decreased by decrease rate. + Finally, the MultifactorSchedule gets called with the adjusted constant. + + Args: + history: trax.history.History, the history of training and evaluation. + constant: float, the starting constant for the learning rate schedule. + steps_to_decrease: int, after how many steps without improvement + should we decrease the constant. + improvement_margin: how much we need to improve to consider the metric + improved. + decrease_rate: by what fraction to decrease (i.e. lr /= decrease_rate). + history_mode: str, which mode of the history to use. + metric: which evaluation metric to use for adjustments. + + Returns: + a function learning_rate(step): float -> {"learning_rate": float}, the + step-dependent lr. + """ + metrics = history.get(history_mode, metric) + adjusted = constant + if len(metrics) < 2: + return MultifactorSchedule(history, constant=adjusted) + + steps_without_improvement = 0 + cur = metrics.pop()[1] # The most-recent value of the metric. + while len(metrics) > 1: + # The one-before value of metrics as .pop() removes one element each time. + prev = metrics.pop()[1] + if cur < prev * (1 + improvement_margin): + steps_without_improvement += 1 + else: + cur = prev + steps_without_improvement = 0 + if steps_without_improvement >= steps_to_decrease: + adjusted /= decrease_rate + cur = prev + steps_without_improvement = 0 + + return MultifactorSchedule(history, constant=adjusted) + + +@gin.configurable(blacklist=["history"]) +def PolicySchedule( + history, + observation_metrics=( + ("train", "metrics/accuracy"), + ("train", "metrics/loss"), + ("eval", "metrics/accuracy"), + ("eval", "metrics/loss"), + ), + include_controls_in_observation=False, + control_configs=( + # (name, start, (low, high), flip) + ("learning_rate", 1e-3, (1e-9, 10.0), False), + ), + observation_range=(0.0, 10.0), + action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5), + policy_and_value_model=trax_models.FrameStackMLP, + policy_and_value_two_towers=False, + policy_and_value_vocab_size=None, + policy_dir=gin.REQUIRED, + temperature=1.0, +): + """Learning rate schedule controlled by a learned policy. + + Args: + history: the history of training and evaluation (History object). + observation_metrics: list of pairs (mode, metric), as in the History object. + include_controls_in_observation: bool, whether to include the controls in + observations. + control_configs: control configs, see trax.rl.envs.OnlineTuneEnv. + observation_range: tuple (low, high), range to clip the metrics to. + action_multipliers: sequence of LR multipliers that policy actions + correspond to. + policy_and_value_model: Trax model to use as the policy. + policy_and_value_two_towers: bool, whether the action distribution and value + prediction is computed by separate model towers. + policy_and_value_vocab_size: vocabulary size of a policy and value network + operating on serialized representation. If None, use raw continuous + representation. + policy_dir: directory with the policy checkpoint. + temperature: temperature for sampling from the policy. + + Returns: + a function nontrainable_params(step): float -> {"name": float}, the + step-dependent schedule for nontrainable parameters. + """ + + # Turn the history into observations for the policy. If we don't have any, + # return the initial learning rate. + start_time = time.time() + observations = online_tune.history_to_observations( + history, observation_metrics, observation_range, + control_configs if include_controls_in_observation else None + ) + logging.vlog( + 1, "Building observations took %0.2f sec.", time.time() - start_time) + if observations.shape[0] == 0: + controls = { + name: start_value + for (name, start_value, _, _) in control_configs + } + return lambda _: controls + + assert policy_and_value_vocab_size is None, ( + "Serialized policies are not supported yet." + ) + # Build the policy network and load its parameters. + start_time = time.time() + net = ppo.policy_and_value_net( + n_controls=len(control_configs), + n_actions=len(action_multipliers), + vocab_size=policy_and_value_vocab_size, + bottom_layers_fn=policy_and_value_model, + two_towers=policy_and_value_two_towers, + ) + logging.vlog( + 1, "Building the policy network took %0.2f sec.", time.time() - start_time + ) + start_time = time.time() + # (opt_state, state, epoch, opt_step) + (opt_state, state, _, _) = ppo.maybe_restore_opt_state(policy_dir) + assert opt_state is not None, "Policy checkpoint not found." + (params, _) = opt_state + logging.vlog( + 1, "Restoring the policy parameters took %0.2f sec.", + time.time() - start_time + ) + + # Run the policy and sample an action. + seed = random.randint(0, 2**31 - 1) + rng = jax_random.get_prng(seed=seed) + start_time = time.time() + # ((log_probs, value_preds), state). We have no way to pass state to the next + # step, but that should be fine. + (log_probs, _) = ( + net(np.array([observations]), params=params, state=state, rng=rng)) + logging.vlog( + 1, "Running the policy took %0.2f sec.", time.time() - start_time + ) + # Sample from the action distribution for the last timestep. + assert log_probs.shape == ( + 1, len(control_configs) * observations.shape[0], len(action_multipliers) + ) + action = utils.gumbel_sample( + log_probs[0, -len(control_configs):, :] / temperature + ) + + # Get new controls. + controls = { + # name: value + control_config[0]: online_tune.update_control( # pylint: disable=g-complex-comprehension + control_config, control_action, history, action_multipliers + ) + for (control_action, control_config) in zip(action, control_configs) + } + return lambda _: controls diff --git a/trax/learning_rate_test.py b/trax/learning_rate_test.py new file mode 100644 index 000000000..82a9de234 --- /dev/null +++ b/trax/learning_rate_test.py @@ -0,0 +1,119 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.learning_rate.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as onp +from tensorflow import test +from trax import history as trax_history +from trax import learning_rate +from trax.backend import numpy as np +from trax.backend import random as jax_random +from trax.models import atari_cnn +from trax.rl import online_tune +from trax.rl import ppo + + +class PolicyScheduleTest(test.TestCase): + + def _make_schedule( + self, + history, + control_configs, + observation_metrics=(("eval", "metrics/accuracy"),), + action_multipliers=(1.0,), + ): + policy_and_value_model = atari_cnn.FrameStackMLP + net = ppo.policy_and_value_net( + n_actions=len(action_multipliers), + n_controls=len(control_configs), + vocab_size=None, + bottom_layers_fn=policy_and_value_model, + two_towers=False, + ) + rng = jax_random.get_prng(seed=0) + obs_dim = len(observation_metrics) + (params, state) = net.initialize_once((1, 1, obs_dim), np.float32, rng) + policy_dir = self.get_temp_dir() + # Optimizer slots should not be used for anything. + slots = None + opt_state = (params, slots) + ppo.save_opt_state(policy_dir, opt_state, state, epoch=0, total_opt_step=0) + return learning_rate.PolicySchedule( + history, + observation_metrics=observation_metrics, + include_controls_in_observation=False, + action_multipliers=action_multipliers, + control_configs=control_configs, + policy_and_value_model=policy_and_value_model, + policy_and_value_two_towers=False, + policy_dir=policy_dir, + ) + + def test_returns_start_lr_when_there_are_no_metrics(self): + history = trax_history.History() + start_lr = 1e-3 + schedule = self._make_schedule( + history, + control_configs=(("learning_rate", start_lr, (1e-9, 1.0), False),), + ) + self.assertEqual(schedule(0)["learning_rate"], start_lr) + + def test_changes_lr_when_there_are_some_metrics(self): + history = trax_history.History() + history.append("eval", "metrics/accuracy", step=0, value=0.8) + history.append( + *online_tune.control_metric("learning_rate"), step=0, value=1e-4 + ) + schedule = self._make_schedule( + history, + control_configs=(("learning_rate", 1e-3, (1e-9, 1.0), False),), + observation_metrics=(("eval", "metrics/accuracy"),), + action_multipliers=(0.5, 2.0), + ) + new_lr = schedule(123)["learning_rate"] + self.assertTrue( + onp.allclose(new_lr, 5e-5) or onp.allclose(new_lr, 2e-4) + ) + + def test_works_with_multiple_controls(self): + history = trax_history.History() + history.append("eval", "metrics/accuracy", step=0, value=0.8) + history.append( + *online_tune.control_metric("learning_rate"), step=0, value=1e-4 + ) + history.append( + *online_tune.control_metric("weight_decay_rate"), step=0, value=1e-5 + ) + schedule = self._make_schedule( + history, + observation_metrics=(("eval", "metrics/accuracy"),), + control_configs=( + ("learning_rate", 1e-3, (1e-9, 1.0), False), + ("weight_decay_rate", 1e-5, (1e-9, 1.0), False), + ), + action_multipliers=(1.0,), + ) + new_controls = schedule(123) + self.assertIn("learning_rate", new_controls) + self.assertIn("weight_decay_rate", new_controls) + + +if __name__ == "__main__": + test.main() diff --git a/trax/models/__init__.py b/trax/models/__init__.py new file mode 100644 index 000000000..3a5e594dc --- /dev/null +++ b/trax/models/__init__.py @@ -0,0 +1,51 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Models defined in trax.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gin + +from trax.models import atari_cnn +from trax.models import mlp +from trax.models import neural_gpu +from trax.models import resnet +from trax.models import transformer +from trax.models.research import position_lookup_transformer +from trax.models.research import reformer + + +# Ginify +def model_configure(*args, **kwargs): + kwargs["module"] = "trax.models" + return gin.external_configurable(*args, **kwargs) + + +# pylint: disable=invalid-name +AtariCnn = model_configure(atari_cnn.AtariCnn) +FrameStackMLP = model_configure(atari_cnn.FrameStackMLP) +MLP = model_configure(mlp.MLP) +NeuralGPU = model_configure(neural_gpu.NeuralGPU) +PositionLookupTransformerLM = model_configure( + position_lookup_transformer.PositionLookupTransformerLM) +ReformerLM = model_configure(reformer.ReformerLM) +Resnet50 = model_configure(resnet.Resnet50) +Transformer = model_configure(transformer.Transformer) +TransformerDecoder = model_configure(transformer.TransformerDecoder) +TransformerEncoder = model_configure(transformer.TransformerEncoder) +TransformerLM = model_configure(transformer.TransformerLM) +WideResnet = model_configure(resnet.WideResnet) diff --git a/trax/models/atari_cnn.py b/trax/models/atari_cnn.py new file mode 100644 index 000000000..b3f98afc6 --- /dev/null +++ b/trax/models/atari_cnn.py @@ -0,0 +1,79 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simple net for playing Atari games using PPO.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from trax import layers as tl + + +def FrameStack(n_frames): + """Stacks a fixed number of frames along the dimension 1.""" + # Input shape: (B, T, ..., C). + # Output shape: (B, T, ..., C * n_frames). + assert n_frames >= 1 + if n_frames == 1: + return () + return ( + # Make n_frames copies of the input sequence. + [tl.Dup()] * (n_frames - 1), + # Shift copies to the right by [0, .., n_frames - 1] frames. + tl.Parallel(*map(_shift_right, range(n_frames))), + # Concatenate along the channel dimension. + tl.Concatenate(n_items=n_frames, axis=-1), + ) + + +def AtariCnn(n_frames=4, hidden_sizes=(32, 32), output_size=128, mode='train'): + """An Atari CNN.""" + del mode + + # TODO(jonni): Include link to paper? + # Input shape: (B, T, H, W, C) + # Output shape: (B, T, output_size) + return tl.Model( + tl.ToFloat(), + tl.Div(divisor=255.0), + + # Set up n_frames successive game frames, concatenated on the last axis. + FrameStack(n_frames=n_frames), # (B, T, H, W, 4C) + + tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'), + tl.Relu(), + tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'), + tl.Relu(), + tl.Flatten(n_axes_to_keep=2), # B, T and rest. + tl.Dense(output_size), + tl.Relu(), + ) + + +def FrameStackMLP(n_frames=4, hidden_sizes=(64,), output_size=64, + mode='train'): + """MLP operating on a fixed number of last frames.""" + del mode + + return tl.Model( + FrameStack(n_frames=n_frames), + [[tl.Dense(d_hidden), tl.Relu()] for d_hidden in hidden_sizes], + tl.Dense(output_size), + ) + + +def _shift_right(n): # pylint: disable=invalid-name + return [tl.ShiftRight()] * n diff --git a/trax/models/atari_cnn_test.py b/trax/models/atari_cnn_test.py new file mode 100644 index 000000000..1db635e12 --- /dev/null +++ b/trax/models/atari_cnn_test.py @@ -0,0 +1,65 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.models.atari_cnn.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import operator as op +import numpy as onp +from tensorflow import test +from trax.backend import random as jax_random +from trax.models import atari_cnn + + +class AtariCnnTest(test.TestCase): + + def test_computes(self): + rng_key = jax_random.get_prng(0) + hidden_size = (4, 4) + output_size = 6 + model = atari_cnn.AtariCnn( + hidden_sizes=hidden_size, output_size=output_size) + B, T, OBS = 2, 2, (28, 28, 3) # pylint: disable=invalid-name + rng_key, key = jax_random.split(rng_key) + _, _ = model.initialize_once((1, 1) + OBS, onp.float32, key) + x = onp.arange(B * (T + 1) * functools.reduce(op.mul, OBS)).reshape( + B, T + 1, *OBS) + y = model(x) + self.assertEqual((B, T + 1, output_size), y.shape) + + +class FrameStackMLPTest(test.TestCase): + + def test_computes(self): + rng_key = jax_random.get_prng(0) + hidden_size = (4, 4) + output_size = 6 + model = atari_cnn.FrameStackMLP( + hidden_sizes=hidden_size, output_size=output_size) + B, T, OBS = 2, 2, 3 # pylint: disable=invalid-name + rng_key, key = jax_random.split(rng_key) + _, _ = model.initialize_once((1, 1, OBS), onp.float32, key) + x = onp.arange(B * (T + 1) * OBS).reshape( + B, T + 1, OBS) + y = model(x) + self.assertEqual((B, T + 1, output_size), y.shape) + + +if __name__ == "__main__": + test.main() diff --git a/trax/models/mlp.py b/trax/models/mlp.py new file mode 100644 index 000000000..98c1eea4f --- /dev/null +++ b/trax/models/mlp.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MLP.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from trax import layers as tl + + +def MLP(n_hidden_layers=2, + d_hidden=512, + activation_fn=tl.Relu, + n_output_classes=10, + mode="train"): + """A multi-layer feedforward (perceptron) network.""" + del mode + + return tl.Model( + tl.Flatten(), + [[tl.Dense(d_hidden), activation_fn()] for _ in range(n_hidden_layers)], + tl.Dense(n_output_classes), + tl.LogSoftmax(), + ) diff --git a/trax/models/mlp_test.py b/trax/models/mlp_test.py new file mode 100644 index 000000000..6897890d4 --- /dev/null +++ b/trax/models/mlp_test.py @@ -0,0 +1,39 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for MLP.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from trax import backend +from trax import layers as tl +from trax.models import mlp + + +class MLPTest(absltest.TestCase): + + def test_mlp_forward_shape(self): + """Run the MLP model forward and check output shape.""" + input_shape = (3, 28, 28, 1) + model = mlp.MLP(d_hidden=32, n_output_classes=10) + final_shape = tl.check_shape_agreement(model, input_shape) + self.assertEqual((3, 10), final_shape) + + +if __name__ == '__main__': + absltest.main() diff --git a/trax/models/neural_gpu.py b/trax/models/neural_gpu.py new file mode 100644 index 000000000..d5cc78d69 --- /dev/null +++ b/trax/models/neural_gpu.py @@ -0,0 +1,82 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementation of the improved Neural GPU (NGPU).""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from trax import layers as tl +from trax.backend import numpy as np + + +# TODO(ddohan): Combinator to add saturation costs to loss +def SaturationCost(x, limit=0.9): + return np.minimum(0, np.abs(x) - limit) + + +@tl.layer() +def DiagonalGate(x, params, **kwargs): + """Split channels in 3 parts. Shifts 1st and 3rd sections to left/right.""" + del params + del kwargs + # x : [batch, 1, length, depth] + x = np.pad( + x, [(0, 0), (0, 0), (1, 1), (0, 0)], mode='constant', constant_values=0.0) + depth = x.shape[-1] // 3 + assert 3 * depth == x.shape[-1], ('Depth must be divisible by 3', depth, + x.shape) + xs = [ + x[:, :, :-2, :depth], x[:, :, 1:-1, depth:2 * depth], + x[:, :, 2:, 2 * depth:3 * depth] + ] + return np.concatenate(xs, axis=3) + + +def ConvDiagonalGRU(units, kernel_size=(3, 3)): + """Build convolutional GRU with diagonal gating as in ImprovedNGPU.""" + + def BuildConv(): + return tl.Conv(filters=units, kernel_size=kernel_size, padding='SAME') + + return tl.GeneralGRUCell( + candidate_transform=BuildConv, + memory_transform_fn=DiagonalGate, + gate_nonlinearity=tl.HardSigmoid, + candidate_nonlinearity=tl.HardTanh) + + +def NeuralGPU(d_feature=96, steps=16, vocab_size=2, mode='train'): + """Implementation of Neural GPU: https://arxiv.org/abs/1702.08727. + + Args: + d_feature: Number of memory channels (dimensionality of feature embedding). + steps: Number of times depthwise recurrence steps. + vocab_size: Vocabulary size. + mode: Whether we are training or evaluating or doing inference. + + Returns: + A NeuralGPU Stax model. + """ + del mode + + core = ConvDiagonalGRU(units=d_feature) + return tl.Model( + tl.Embedding(d_feature=d_feature, vocab_size=vocab_size), + [core] * steps, + tl.Dense(vocab_size), + tl.LogSoftmax(), + ) diff --git a/trax/models/neural_gpu_test.py b/trax/models/neural_gpu_test.py new file mode 100644 index 000000000..6aa252112 --- /dev/null +++ b/trax/models/neural_gpu_test.py @@ -0,0 +1,39 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.models.neural_gpu.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from trax.layers import base +from trax.models import neural_gpu + + +class NeuralGPUTest(absltest.TestCase): + + def test_ngpu(self): + vocab_size = 2 + input_shape = [3, 5, 7] + model = neural_gpu.NeuralGPU(d_feature=30, steps=4, vocab_size=vocab_size) + final_shape = base.check_shape_agreement( + model, tuple(input_shape), integer_inputs=True) + self.assertEqual(tuple(input_shape + [vocab_size]), final_shape) + + +if __name__ == '__main__': + absltest.main() diff --git a/trax/models/research/__init__.py b/trax/models/research/__init__.py new file mode 100644 index 000000000..7fa0b7f96 --- /dev/null +++ b/trax/models/research/__init__.py @@ -0,0 +1,15 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/trax/models/research/position_lookup_transformer.py b/trax/models/research/position_lookup_transformer.py new file mode 100644 index 000000000..34c910d21 --- /dev/null +++ b/trax/models/research/position_lookup_transformer.py @@ -0,0 +1,341 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Deep Lookups for Transformer Positions.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as onp + +from trax import layers as tl +from trax.backend import numpy as np + + +# pylint: disable=g-complex-comprehension +# pylint: disable=no-value-for-parameter + +POS_VECTOR_SIZE = 32 +_ABSOLUTE_MAX_LEN = 10000 +_POSITIONS = onp.random.uniform(size=[_ABSOLUTE_MAX_LEN, POS_VECTOR_SIZE]) + + +def Dup2(): + """Copy first 2 elements of the stack: (a, b, ...) -> (a, b, a, b, ...).""" + return [ # Stack is (a, b, ...) + tl.Parallel(tl.Dup(), tl.Dup()), # Stack is (a, a, b, b, ...) + tl.Parallel([], tl.Swap()) # Stack is (a, b, a, b, ...) + ] + + +@tl.layer() +def NewPositionalEncoding(x, positions=None, **kwargs): + """Implements new positional encoding.""" + del kwargs + x_length = np.shape(x)[1] + pos = np.array(positions)[np.newaxis, :x_length, :] + pos += np.zeros((np.shape(x)[0], 1, 1)) # Broadcast on batch. + res = np.concatenate([x, pos], axis=2) + return res + + +@tl.layer(n_inputs=1, n_outputs=2) +def CutAtPosition(x, **unused_kwargs): + """Splits x into a pair (x[:position], position).""" + return tuple([x[:, :, :-POS_VECTOR_SIZE], x[:, :, -POS_VECTOR_SIZE:]]) + + +@tl.layer() +def MixHeadsPos(x, h=8, **unused_kwargs): + """Mix x = (x0, p) into x0_h1, p, x0_h2, p, ....""" + head_size = (x.shape[2] - POS_VECTOR_SIZE) // h + p = x[:, :, -POS_VECTOR_SIZE:] + res, idx = [], 0 + for _ in range(h): + res.append(x[:, :, idx:idx+head_size]) + res.append(p) + idx += head_size + return np.concatenate(res, axis=-1) + + +@tl.layer() +def CombineHeadsPos(x, h=8, **unused_kwargs): + """Mix x = (x0, p0, ..., xH, pH) into x0, ...., xH, p_combined. + + The positions are added as vectors. + + Args: + x: input vector, concatenated (x0, p0, ..., xH, pH). + h: number of heads. + + Returns: + the vector with combined positions. + """ + head_size = int((x.shape[2] / h) - POS_VECTOR_SIZE) + res, positions, idx = [], [], 0 + for _ in range(h): + res.append(x[:, :, idx:idx+head_size]) + idx += head_size + positions.append(x[:, :, idx:idx+POS_VECTOR_SIZE]) + idx += POS_VECTOR_SIZE + combined_position = sum(positions) + res.append(combined_position) + return np.concatenate(res, axis=-1) + + +@tl.layer() +def CopyHeadsPos(x, h=8, **unused_kwargs): + """Mix x = (x, p) into x_h1, p_h1, x_h2, p_h2, ....""" + head_size = (x.shape[2] - h*POS_VECTOR_SIZE) // h + p = x[:, :, -h*POS_VECTOR_SIZE:] + res, idx = [], 0 + for i in range(h): + res.append(x[:, :, idx:idx+head_size]) + res.append(p[:, :, i*POS_VECTOR_SIZE:(i+1)*POS_VECTOR_SIZE]) + idx += head_size + return np.concatenate(res, axis=-1) + + +def DeepFlatten(xs): + for x in xs: + if isinstance(x, (list, tuple)): + for y in DeepFlatten(x): + yield y + else: + yield x + + +def PreservePosition(layer): + """Execute layer without position but preserve it in parallel.""" + return tl.Serial( + CutAtPosition(), + layer, + tl.Concatenate(n_items=2) + ) + + +def ApplyAndQueryPositions(layer, pos): + """Execute layer without position and pos-layers on positions. + + This takes an embedding including position x = (emb, p), and + outputs layer(emb).pos1(x, p).....layer(emb).posn(x, p) + where pos=[pos1...posn]. + + Args: + layer: layer to be executed without position information. + pos: list of layers to be applied to positions. + + Returns: + the result of this application. + """ + n_heads = len(pos) + return tl.Serial( + tl.Dup(), # (x, x) + CutAtPosition(), # (x_content, x_position, x) + tl.Parallel([], tl.Swap()), # (x_content, x, x_position) + [tl.Parallel([], Dup2()) for _ in range(n_heads - 1)], + # Now the stack is x_content, (x, x_position) * n_heads. + tl.Parallel(*([layer] + pos)), + tl.Concatenate(n_items=n_heads + 1) + ) + + +@tl.layer() +def QueryPositionKV(x, keys=None, values=None, binary=False, **unused_kwargs): + """Query a table with a position vector.""" + if keys is None: + return x + k = np.array(keys) + v = np.array(values) + q = x + if binary: + q = np.concatenate([x, x], axis=-1) + return tl.DotProductAttention(q, k, v, None, None, None, None) + + +def LearnedQP(keys=None, values=None, binary=False): + """Get (query, pos), make learned weight of qeury and return with pos.""" + return tl.Parallel( + tl.Dense(1), + QueryPositionKV(keys=keys, values=values, binary=binary), + ) + + +@tl.layer(n_inputs=10, n_outputs=1) +def Softmax5Branches(x_list, n_branches=2, **unused_kwargs): + """Softmax xs. + + The input xs is a list of embeddings and weights of the form + w_1 e_1 .... w_n e_n (followed by optional rest that is preserved). + + Args: + x_list: the input weights and embeddings. + n_branches: what part of the list to use. + + Returns: + softmax(w) * e for the joint weights w and embeddings e. + """ + assert n_branches == 5 + softmax_activations = [x_list[2*i] for i in range(n_branches)] + max_sa = softmax_activations[0] + for x in softmax_activations: + max_sa = np.maximum(max_sa, x) + softmax_activations = [x - max_sa for x in softmax_activations] + softmax_activations = [np.exp(x) for x in softmax_activations] + sum_sa = sum(softmax_activations) + softmax_activations = [x / sum_sa for x in softmax_activations] + res = sum([x_list[2*i+1] * softmax_activations[i] for i in range(n_branches)]) + return res + + +def SumLearnedPick(positions): + """Get a pair (vec, pos) and pick new pos.""" + succ_keys = positions[:-1, :] + succ_values = positions[1:, :] + subtract_1_keys = positions[1:, :] + subtract_1_values = positions[:-1, :] + l = int(positions.shape[0]) // 2 + add_keys = np.array([np.concatenate([positions[i, :], positions[j, :]]) + for i in range(l) for j in range(l)]) + add_values = np.array([positions[i + j, :] + for i in range(l) for j in range(l)]) + # TODO(lukaszkaiser): try this below: "for j in range(i) for i in range(2*l)" + sub_keys = np.array([np.concatenate([positions[i, :], positions[j, :]]) + for j in range(l) for i in range(l)]) + sub_values = np.array([positions[max(i - j, 0), :] + for j in range(l) for i in range(l)]) + return tl.Serial( + Dup2(), Dup2(), Dup2(), Dup2(), + tl.Parallel( + LearnedQP(), + LearnedQP(keys=succ_keys, values=succ_values), + LearnedQP(keys=subtract_1_keys, values=subtract_1_values), + LearnedQP(keys=add_keys, values=add_values, binary=True), + LearnedQP(keys=sub_keys, values=sub_values, binary=True), + ), + Softmax5Branches(n_branches=5) + ) + + +def AttentionPosition(positions, d_model, n_heads=8, dropout=0.0, + mode='train'): + """Transformer-style multi-headed attention.""" + return tl.Serial( + tl.Dup(), + tl.Dup(), + tl.Parallel( + ApplyAndQueryPositions(tl.Dense(d_model), + pos=[SumLearnedPick(positions) + for _ in range(n_heads)]), + PreservePosition(tl.Dense(d_model)), + PreservePosition(tl.Dense(d_model)), + ), + tl.Parallel( + CopyHeadsPos(h=n_heads), + MixHeadsPos(h=n_heads), + MixHeadsPos(h=n_heads), + ), + tl.PureAttention(d_model=d_model, n_heads=n_heads, dropout=dropout, + mode=mode), + tl.Parallel([], tl.Drop()), # Drop the mask. + CombineHeadsPos(h=n_heads), + PreservePosition(tl.Dense(d_model)), + ) + + +def ResidualFeedForward(d_model, + d_ff, + dropout, + mode): + """Residual feed-forward layer with normalization at start.""" + stack = tl.Serial( + tl.LayerNorm(), + tl.Dense(d_ff), + tl.Relu(), + tl.Dropout(rate=dropout, mode=mode), + tl.Dense(d_model), + tl.Dropout(rate=dropout, mode=mode) + ) + return tl.Residual(PreservePosition(stack)) + + +def DecoderLayer(positions, + d_model, + d_ff, + n_heads, + dropout, + mode): + """Transformer decoder layer. + + Args: + positions: random vectors for positions + d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + n_heads: int: number of attention heads + dropout: float: dropout rate (how much to drop out) + mode: str: 'train' or 'eval' + + Returns: + the layer. + """ + return [ + tl.Residual( # Self-attention block. + PreservePosition(tl.LayerNorm()), + tl.Dup(), + tl.Parallel([], # activation for (q, k, v) + tl.CausalMask(axis=-2)), # attention mask + AttentionPosition(positions, d_model, n_heads=n_heads, + dropout=dropout, mode=mode), + PreservePosition(tl.Dropout(rate=dropout, mode=mode)) + ), + ResidualFeedForward(d_model, d_ff, dropout, mode=mode) + ] + + +def PositionLookupTransformerLM(vocab_size=128, + d_model=256, + d_ff=512, + n_layers=3, + n_heads=4, + dropout=0.1, + max_len=100, + mode='train'): + """Transformer language model (only uses the decoder part of Transformer). + + Args: + vocab_size: int: vocab size + d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + n_layers: int: number of layers + n_heads: int: number of attention heads + dropout: float: dropout rate (how much to drop out) + max_len: maximal length + mode: str: 'train' or 'eval' + + Returns: + the layer. + """ + positions = _POSITIONS[:max_len, :] + return tl.Serial( + tl.ShiftRight(), + tl.Embedding(d_model, vocab_size), + tl.Dropout(rate=dropout, mode=mode), + NewPositionalEncoding(positions=positions), + [DecoderLayer(positions, d_model, d_ff, n_heads, dropout, mode) + for _ in range(n_layers)], + PreservePosition(tl.LayerNorm()), + tl.Dense(vocab_size), + tl.LogSoftmax() + ) diff --git a/trax/models/research/reformer.py b/trax/models/research/reformer.py new file mode 100644 index 000000000..9c29eb1a2 --- /dev/null +++ b/trax/models/research/reformer.py @@ -0,0 +1,531 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transformer Models.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import jax + +from trax import backend +from trax import layers as tl +from trax.layers.combinators import _pop_rng_and_split + + +# Layers are always CamelCase, but functions in general are snake_case +# pylint: disable=invalid-name + + +class Map(tl.Layer): + """Combinator for applying a layer to a list or tuple.""" + + def __init__(self, layer, n_sections=1, check_shapes=True): + """Initialize the combinator. + + Args: + layer: a layer to apply to each element. + n_sections: how many sections to map to (default: 1). + check_shapes: whether to check that shapes are identical (default: true). + + Returns: + A new layer representing mapping layer to all elements of the input. + """ + super(Map, self).__init__(n_inputs=n_sections, n_outputs=n_sections) + if layer is None or isinstance(layer, (list, tuple)): + layer = tl.Serial(layer) + self._layer = layer + # Generally a Map should be applied to lists where all elements have + # the same shape -- because self._layer will only be initialized once + # and it could have different parameters for different shapes. But there + # are valid cases -- e.g., when self._layer has no parameters -- where we + # can apply Map to different shapes -- set check_shapes=False in such cases. + self._check_shapes = check_shapes + self._n_sections = n_sections + + def forward(self, inputs, params=(), state=(), **kwargs): + rngs = _pop_rng_and_split(kwargs, len(inputs)) + results = [self._layer(x, params=params, state=state, rng=r, **kwargs) + for x, r in zip(inputs, rngs)] + # TODO(kitaev): think about how to merge state across copies in the map. + return tuple(results), self._layer.state + + def new_params_and_state(self, input_shape, input_dtype, rng): + first_shape = input_shape[0] + if self._check_shapes: + for shape in input_shape: + if shape != first_shape: + raise ValueError('Map layer can only be applied to list of elements ' + 'with the same shapes. Shapes: %s' % str(shape)) + return self._layer.initialize_once(first_shape, input_dtype[0], rng) + + @tl.Layer.params.setter + def params(self, params): + self._params = params + assert len(params) == 1 + self._layer.params = params[0] + + @tl.Layer.state.setter + def state(self, state): + self._state = state + assert len(state) == 1 + self._layer.state = state[0] + + +@tl.layer() +def BroadcastedDropout(x, params, rate=0.0, mode='train', broadcast_dims=(-2,), + rng=None, **kwargs): + """Dropout, with broadcasting to save memory.""" + del params, kwargs + if rng is None: + raise ValueError('BroadcastedDropout requires rng kwarg.') + if rate >= 1.0: + raise ValueError('Dropout rate (%f) must be lower than 1.' % rate) + if mode == 'train' and rate > 0.0: + noise_shape = list(x.shape) + for dim in broadcast_dims: + noise_shape[dim] = 1 + keep_prob = jax.lax.tie_in(rng, 1.0 - rate) + keep = backend.random.bernoulli(rng, keep_prob, tuple(noise_shape)) + multiplier = keep.astype(x.dtype) / jax.lax.tie_in(keep, keep_prob) + return x * multiplier + else: + return x + + +def FeedForward(d_model, d_ff, dropout, mode): + """Feed-forward block with layer normalization at start.""" + return [ + tl.LayerNorm(), + tl.Dense(d_ff), + BroadcastedDropout(rate=dropout, mode=mode), # pylint: disable=no-value-for-parameter + tl.Relu(), + tl.Dense(d_model), + BroadcastedDropout(rate=dropout, mode=mode), # pylint: disable=no-value-for-parameter + ] + + +class SplitForOutput(tl.ReversibleLayer): + """Splits activations into sections (for use right before the output layer). + + After the reversible portion of the network, there is a final output portion + that's non-reversible (which at minimum includes normalization, output + projection, and log-softmax). The output portion needs to operate on chunks + of the sequence to avoid running out of memory for large vocabulary sizes. + + This layer concatenates the two subparts of the activations along the feature + dimension, and then splits into chunks along the time dimension. We implement + it is a subclass of tl.ReversibleLayer because we want to ensure that multiple + copies of the activations don't exist simultaneously except in the middle of a + memory copy operation. + """ + + def __init__(self, n_sections=2, axis=-2): + super(SplitForOutput, self).__init__(n_inputs=2, n_outputs=n_sections) + self._n_sections = n_sections + self._axis = axis + + def forward(self, inputs, params=(), state=(), **kwargs): + del params, kwargs + x1, x2 = inputs + + x1_split = backend.numpy.split(x1, self._n_sections, self._axis) + x2_split = backend.numpy.split(x2, self._n_sections, self._axis) + + res = [backend.numpy.concatenate(ys, -1) for ys in zip(x1_split, x2_split)] + return tuple(res), state + + def reverse(self, output, params=(), state=(), **kwargs): + del params, kwargs + + x1_split = [] + x2_split = [] + for y in output: + y1, y2 = backend.numpy.split(y, 2, -1) + x1_split.append(y1) + x2_split.append(y2) + + x1 = backend.numpy.concatenate(x1_split, self._axis) + x2 = backend.numpy.concatenate(x2_split, self._axis) + + return (x1, x2) + + def reverse_and_grad(self, output, ct, params=(), state=(), **kwargs): + del params, kwargs + return self.reverse(output), (self.reverse(ct), ()) + + +@tl.layer() +def Chunk(x, params, n_sections=2, **kwargs): + del params, kwargs + assert x.shape[1] % n_sections == 0 + return backend.numpy.reshape(x, ( + x.shape[0] * n_sections, + x.shape[1] // n_sections, + ) + x.shape[2:]) + + +@tl.layer() +def Unchunk(x, params, n_sections=2, **kwargs): + del params, kwargs + assert x.shape[0] % n_sections == 0 + return backend.numpy.reshape(x, ( + x.shape[0] // n_sections, + x.shape[1] * n_sections, + ) + x.shape[2:]) + + +class ReversibleHalfResidual(tl.ReversibleLayer, tl.Serial): + """Half of a RevNet-style residual (only updates part of the hidden state).""" + + def __init__(self, residual_layers): + self.compute_residual = tl.Serial([ + # (x1_or_y1, x2) -> (x2, x1_or_y1, x2) + tl.Parallel([], tl.Dup()), + tl.Swap(), + tl.Parallel(residual_layers, [], []), + ]) + + layers = [ + self.compute_residual, + tl.Parallel(tl.Add(), []) + ] + super(ReversibleHalfResidual, self).__init__(layers) + + self.subtract_top = tl.Parallel(tl.SubtractTop(), []) + self.reverse_layers = [self.compute_residual, self.subtract_top] + + def reverse(self, output, params=(), state=(), **kwargs): + reconstructed_x = output + rng = kwargs.pop('rng', None) + rngs = (None,) * self._n_layers + if rng is not None: + rngs = backend.random.split(rng, self._n_layers) + # Note that self.sublayers aligns exactly with self.reverse_layers in + # terms of parameter and rng usage, so no re-ordering is required. + for layer, p, s, rng in zip(self.reverse_layers, params, state, rngs): + reconstructed_x = layer(reconstructed_x, params=p, state=s, rng=rng, + **kwargs) + return reconstructed_x + + def reverse_and_grad(self, output, ct, params=(), state=(), **kwargs): + rng = kwargs.pop('rng', None) + rngs = (None,) * self._n_layers + if rng is not None: + rngs = backend.random.split(rng, self._n_layers) + + def call_compute_residual(x, params): + res = self.compute_residual(x, params=params, state=state[0], rng=rngs[0], + **kwargs) + return res + + assert len(ct) == 2 + ct = ((ct[0], ct[0], ct[1])) + + stack_with_residual, vjpfun = jax.vjp( + call_compute_residual, output, params[0]) + reconstructed_x = self.subtract_top( + stack_with_residual, params=params[-1], state=state[-1], rng=rngs[-1], + **kwargs) + + x_ct, residual_params_ct = vjpfun(ct) + assert not jax.tree_util.tree_leaves(params[-1]) + add_top_params_ct = params[-1] + return reconstructed_x, (x_ct, [residual_params_ct, add_top_params_ct]) + + +class ApplyAttentionWrapper(tl.Parallel): + """Like tl.Parallel(attention, [], []) but implements forward_and_backward.""" + + def __init__(self, attention): + assert hasattr(attention, 'forward_and_backward') + super(ApplyAttentionWrapper, self).__init__(attention, [], []) + self.attention = attention + + def forward_and_backward(self, inputs, ct, rng=None, **kwargs): + # Simultaneous forward pass and backprop through the attention mechanism. + qkv = inputs[:3] + passthrough = inputs[3:] + out_ct = ct[0] + passthrough_ct = ct[1:] + if rng is not None: + # Adjust RNG to match the forward pass. + rng = backend.random.split(rng, self._n_layers)[0] + + out, qkv_ct = self.attention.forward_and_backward( + qkv, out_ct, rng=rng, **kwargs) + return (out,) + passthrough, qkv_ct + passthrough_ct + + +class ReversibleAttentionHalfResidual(tl.ReversibleLayer, tl.Serial): + """Half of a RevNet-style residual that performs attention. + + If inputs are (x1, x2), then outputs are (x1 + z, x2) where: + z = post_attention(attention(pre_attention(x1))) + + Other than an efficiency optimization, this layer is equivalent to + ReversibleHalfResidual([pre_attention, attention, post_attention]). + + The post_attention layers must be linear in their input (typically they will + consists of reshaping and dense linear layers), which allows the following + optimization. We can back-propagate the gradient signal from the output of + ReversibleAttentionHalfResidual to the output of the "attention" portion based + only on the network parameters. Then, attention.forward_and_backward can be + used to recover the output of the "attention" portion while simultaneously + performing the backward pass, which allows shared computation between the two + directions. + """ + + def __init__(self, pre_attention, attention, post_attention): + self.pre_attention = tl.Serial([ + # (x1_or_y1, x2) -> (x2, x1_or_y1, x2) + tl.Parallel([], tl.Dup()), + tl.Swap(), + tl.Parallel(pre_attention, [], []), + ]) + assert hasattr(attention, 'forward_and_backward') + self.attention = ApplyAttentionWrapper(attention) + self.post_attention = tl.Parallel(post_attention, [], []) + + layers = [ + self.pre_attention, + self.attention, + self.post_attention, + tl.Parallel(tl.Add(), []), + ] + super(ReversibleAttentionHalfResidual, self).__init__(layers) + + self.subtract_top = tl.Parallel(tl.SubtractTop(), []) + self.reverse_layers = [ + self.pre_attention, + self.attention, + self.post_attention, + self.subtract_top, + ] + + def reverse(self, output, params=(), state=(), **kwargs): + rng = kwargs.pop('rng', None) + rngs = (None,) * self._n_layers + if rng is not None: + rngs = backend.random.split(rng, self._n_layers) + + reconstructed_x = output + # Note that self.sublayers aligns exactly with self.reverse_layers in + # terms of parameter and rng usage, so no re-ordering is required. + for layer, p, s, rng in zip(self.reverse_layers, params, state, rngs): + reconstructed_x = layer.reverse(reconstructed_x, params=p, state=s, + rng=rng, **kwargs) + return reconstructed_x + + def reverse_and_grad(self, output, ct, params=(), state=(), **kwargs): + rng = kwargs.pop('rng', None) + rngs = (None,) * self._n_layers + if rng is not None: + rngs = backend.random.split(rng, self._n_layers) + + # Forward pass through self.pre_attention, while preparing for + # later backprop. + def call_pre_attention(x, params): + res = self.pre_attention(x, params=params, state=state[0], rng=rngs[0], + **kwargs) + return res + stack, pre_attention_vjpfun = jax.vjp(call_pre_attention, output, params[0]) + + # Backprop through adding the residual + assert len(ct) == 2 + ct = saved_ct = (ct[0], ct[0], ct[1]) + + # Backprop through self.post_attention with respect to the inputs only + def call_post_attention(x): + res = self.post_attention(x, params=params[2], state=state[2], + rng=rngs[2], **kwargs) + return res + # Note: these are *not* the actual inputs to self.post_attention. + # If self.post_attention is not linear, we will get incorrect gradients. + dummy_inputs = (stack[-3], stack[-2], stack[-1]) + _, post_attention_vjpfun = jax.vjp(call_post_attention, dummy_inputs) + (ct,) = post_attention_vjpfun(ct) + + # Simultaneous forward pass and backprop through the attention mechanism + stack, ct = self.attention.forward_and_backward(stack, ct, rng=rngs[1], + **kwargs) + assert not jax.tree_util.tree_leaves(params[1]) + attention_params_ct = params[1] # This is valid when params is empty. + + # Backprop through self.pre_attention + x_ct, pre_attention_params_ct = pre_attention_vjpfun(ct) + + # Forward pass for self.post_attention, and backprop with respect to the + # parameters only + def call_post_attention2(params): + res = self.post_attention(stack, params=params, state=state[2], + rng=rngs[2], **kwargs) + return res + stack, post_attention_vjpfun = jax.vjp(call_post_attention2, params[2]) + (post_attention_params_ct,) = post_attention_vjpfun(saved_ct) + + # Forward pass through subtracting the residual + reconstructed_x = self.subtract_top( + stack, params=params[-1], state=state[-1], rng=rngs[-1], **kwargs) + + assert not jax.tree_util.tree_leaves(params[-1]) + add_top_params_ct = params[-1] + params_ct = [ + pre_attention_params_ct, + attention_params_ct, + post_attention_params_ct, + add_top_params_ct, + ] + + return reconstructed_x, (x_ct, params_ct) + + +def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, + n_heads, n_attention_chunks, attention_type, + dropout, share_qk, mode): + """Reversible transformer decoder layer. + + Args: + d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + d_attention_key: int: depth of key vector for each attention head + d_attention_value: int: depth of value vector for each attention head + n_heads: int: number of attention heads + n_attention_chunks: int: number of chunks for attention + attention_type: subclass of tl.BaseCausalAttention: attention class to use + dropout: float: dropout rate (how much to drop out) + share_qk: string, whether to share queries and keys + mode: str: 'train' or 'eval' + + Returns: + the layer. + """ + if share_qk: + pre_attention = [ + Chunk(n_sections=n_attention_chunks), # pylint: disable=no-value-for-parameter + tl.LayerNorm(), + tl.Dup(), + tl.Parallel( + tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key), + tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_value), + ), + tl.Dup(), + ] + else: + pre_attention = [ + Chunk(n_sections=n_attention_chunks), # pylint: disable=no-value-for-parameter + tl.LayerNorm(), + tl.Dup(), tl.Dup(), + tl.Parallel( + tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key), + tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key), + tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_value), + ), + ] + + attention = attention_type(mode=mode) + + # ReversibleAttentionHalfResidual requires that post_attention be linear in + # its input (so the backward pass can be computed without knowing the input) + post_attention = [ + tl.ComputeAttentionOutput(n_heads=n_heads, d_model=d_model), + Unchunk(n_sections=n_attention_chunks), # pylint: disable=no-value-for-parameter + BroadcastedDropout(rate=dropout, mode=mode), # pylint: disable=no-value-for-parameter + ] + + feed_forward = [ + FeedForward(d_model, d_ff, dropout, mode=mode), + ] + return [ + ReversibleAttentionHalfResidual(pre_attention, attention, post_attention), + tl.ReversibleSwap(), + ReversibleHalfResidual(feed_forward), + tl.ReversibleSwap(), + ] + + +def ReformerLM(vocab_size, + d_model=512, + d_ff=2048, + d_attention_key=64, + d_attention_value=64, + n_layers=6, + n_heads=8, + dropout=0.1, + max_len=2048, + n_chunks=0, + n_attention_chunks=1, + attention_type=tl.DotProductCausalAttention, + share_qk=False, + mode='train'): + """Reversible transformer language model (only uses a decoder, no encoder). + + Args: + vocab_size: int: vocab size + d_model: int: depth of *each half* of the two-part features + d_ff: int: depth of feed-forward layer + d_attention_key: int: depth of key vector for each attention head + d_attention_value: int: depth of value vector for each attention head + n_layers: int: number of decoder layers + n_heads: int: number of attention heads + dropout: float: dropout rate (how much to drop out) + max_len: int: maximum symbol length for positional encoding + n_chunks: int: number of chunks (must match input pipeline) + n_attention_chunks: int: number of chunks for attention + attention_type: class: attention class to use, such as DotProductAttention. + share_qk: bool, whether to share queries and keys. + mode: str: 'train' or 'eval' + + Returns: + the layer. + """ + if n_chunks == 0: + n_chunks = 1 + concatenate_input_chunks = [] + concatenate_output_chunks = tl.Concatenate(n_items=n_chunks, axis=-2) + else: + concatenate_input_chunks = tl.Concatenate(n_items=n_chunks) + concatenate_output_chunks = [] + + positional_embedder = [ + tl.Embedding(d_model, vocab_size), + BroadcastedDropout(rate=dropout, mode=mode), # pylint: disable=no-value-for-parameter + tl.PositionalEncoding(max_len=max_len), + ] + return tl.Model( + concatenate_input_chunks, + tl.ShiftRight(), + positional_embedder, + tl.Dup(), + tl.ReversibleSerial([ + # pylint: disable=g-complex-comprehension + DecoderBlock(d_model, d_ff, + d_attention_key, d_attention_value, n_heads, + n_attention_chunks, attention_type, + dropout, share_qk, mode) + for _ in range(n_layers) + ] + [ + SplitForOutput(n_sections=n_chunks, axis=-2), # pylint: disable=no-value-for-parameter + ]), + Map([ + # TODO(kitaev): Test whether dropout should go before or after the + # LayerNorm, and whether dropout broadcasting is needed here. + tl.LayerNorm(), + BroadcastedDropout(rate=dropout, mode=mode), # pylint: disable=no-value-for-parameter + tl.Dense(vocab_size), + tl.LogSoftmax(), + ], n_sections=n_chunks), + concatenate_output_chunks, + ) diff --git a/trax/models/research/reformer_test.py b/trax/models/research/reformer_test.py new file mode 100644 index 000000000..604fc6ce1 --- /dev/null +++ b/trax/models/research/reformer_test.py @@ -0,0 +1,104 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Transformer-Revnet models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from absl.testing import parameterized +import jax +import numpy as onp + +from trax import backend +from trax import layers as tl +from trax.backend import numpy as np +from trax.models.research import reformer + + +class PoisonOnRNGMismatchAttention(tl.BaseCausalAttention): + """Fills gradients with NaNs if reverse rng does not match forward rng.""" + + # pylint: disable=protected-access + def forward_and_backward(self, inputs, ct, rng=None, **kwargs): + assert backend.get_name() == 'jax', ( + 'JAX backend is required to use forward_and_backward.') + + if ct is not None and tl.Layer._STASH_OUT is not None: + recovered_rng = tl.Layer._STASH_OUT.pop(self) + is_same = (rng[0] == recovered_rng[0]) & (rng[1] == recovered_rng[1]) + is_same = is_same.astype(np.float32) + # Divides by zero if rngs are not the same, which results in NaNs. + inputs = (inputs[0] / is_same, inputs[1] / is_same, inputs[2] / is_same) + + def _do_forward(x): # pylint: disable=invalid-name + res, _ = self.forward(x, rng=rng, **kwargs) + return res + output, vjpfun = jax.vjp(_do_forward, inputs) + return output, vjpfun(ct)[0] + + def forward(self, inputs, params=(), state=(), rng=None, **kwargs): + if tl.Layer._STASH_IN is not None: + tl.Layer._STASH_IN[self] = rng + return inputs[2], state + # pylint: enable=protected-access + + +class ReformerTest(parameterized.TestCase): + + def test_reformer_lm_forward_shape(self): + """Run the ReformerLM forward and check output shape.""" + vocab_size = 16 + input_shape = ((1, 8), (1, 8)) + model = reformer.ReformerLM( + vocab_size, d_model=32, d_ff=64, + d_attention_key=16, d_attention_value=16, n_layers=1, n_heads=2, + max_len=16, n_chunks=2, n_attention_chunks=1) + final_shape = tl.check_shape_agreement( + model, tuple(input_shape), integer_inputs=True) + self.assertEqual(((1, 8, 16), (1, 8, 16)), final_shape) + + def test_reformer_rng_consistency(self): + with backend.use_backend('jax'): + vocab_size = 16 + batch_size = 1 + input_shape = ((batch_size, 8), (batch_size, 8)) + model = reformer.ReformerLM( + vocab_size, d_model=32, d_ff=64, + d_attention_key=16, d_attention_value=16, n_layers=1, n_heads=2, + max_len=16, n_chunks=2, n_attention_chunks=1, mode='train', + attention_type=PoisonOnRNGMismatchAttention) + + rng = backend.random.get_prng(0) + params, state = model.initialize_once( + input_shape, (np.int32, np.int32), rng) + + def dummy_loss_fn(params): + inputs = (np.zeros(input_shape[0], dtype=np.int32),) * 2 + output = model(inputs, params=params, state=state, rng=rng) + dummy_loss = backend.numpy.sum(output[0]) + return dummy_loss + + grad_fn = backend.grad(dummy_loss_fn) + grads = grad_fn(params) + # PoisonOnRNGMismatchAttention uses NaNs to signal an rng mismatch. + for grad in jax.tree_util.tree_leaves(grads): + assert onp.all(onp.isfinite(grad)) + + +if __name__ == '__main__': + absltest.main() diff --git a/trax/models/resnet.py b/trax/models/resnet.py new file mode 100644 index 000000000..97764e996 --- /dev/null +++ b/trax/models/resnet.py @@ -0,0 +1,169 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ResNet.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from trax import layers as tl + + +def ConvBlock(kernel_size, filters, strides, mode='train'): + """ResNet convolutional striding block.""" + # TODO(jonni): Use good defaults so Resnet50 code is cleaner / less redundant. + ks = kernel_size + filters1, filters2, filters3 = filters + main = [ + tl.Conv(filters1, (1, 1), strides), + tl.BatchNorm(mode=mode), + tl.Relu(), + tl.Conv(filters2, (ks, ks), padding='SAME'), + tl.BatchNorm(mode=mode), + tl.Relu(), + tl.Conv(filters3, (1, 1)), + tl.BatchNorm(mode=mode), + ] + shortcut = [ + tl.Conv(filters3, (1, 1), strides), + tl.BatchNorm(mode=mode), + ] + return [ + tl.Residual(main, shortcut=shortcut), + tl.Relu(), + ] + + +def IdentityBlock(kernel_size, filters, mode='train'): + """ResNet identical size block.""" + # TODO(jonni): Use good defaults so Resnet50 code is cleaner / less redundant. + ks = kernel_size + filters1, filters2, filters3 = filters + main = [ + tl.Conv(filters1, (1, 1)), + tl.BatchNorm(mode=mode), + tl.Relu(), + tl.Conv(filters2, (ks, ks), padding='SAME'), + tl.BatchNorm(mode=mode), + tl.Relu(), + tl.Conv(filters3, (1, 1)), + tl.BatchNorm(mode=mode), + ] + return [ + tl.Residual(main), + tl.Relu(), + ] + + +def Resnet50(d_hidden=64, n_output_classes=1001, mode='train'): + """ResNet. + + Args: + d_hidden: Dimensionality of the first hidden layer (multiplied later). + n_output_classes: Number of distinct output classes. + mode: Whether we are training or evaluating or doing inference. + + Returns: + The list of layers comprising a ResNet model with the given parameters. + """ + return tl.Model( + tl.ToFloat(), + tl.Conv(d_hidden, (7, 7), (2, 2), 'SAME'), + tl.BatchNorm(mode=mode), + tl.Relu(), + tl.MaxPool(pool_size=(3, 3), strides=(2, 2)), + ConvBlock(3, [d_hidden, d_hidden, 4 * d_hidden], (1, 1), mode=mode), + IdentityBlock(3, [d_hidden, d_hidden, 4 * d_hidden], mode=mode), + IdentityBlock(3, [d_hidden, d_hidden, 4 * d_hidden], mode=mode), + ConvBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden], (2, 2), + mode=mode), + IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden], mode=mode), + IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden], mode=mode), + IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden], mode=mode), + ConvBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden], (2, 2), + mode=mode), + IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden], mode=mode), + IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden], mode=mode), + IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden], mode=mode), + IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden], mode=mode), + IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden], mode=mode), + ConvBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden], (2, 2), + mode=mode), + IdentityBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden], mode=mode), + IdentityBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden], mode=mode), + tl.AvgPool(pool_size=(7, 7)), + tl.Flatten(), + tl.Dense(n_output_classes), + tl.LogSoftmax(), + ) + + +def WideResnetBlock(channels, strides=(1, 1), bn_momentum=0.9, mode='train'): + """WideResnet convolutional block.""" + return [ + tl.BatchNorm(momentum=bn_momentum, mode=mode), + tl.Relu(), + tl.Conv(channels, (3, 3), strides, padding='SAME'), + tl.BatchNorm(momentum=bn_momentum, mode=mode), + tl.Relu(), + tl.Conv(channels, (3, 3), padding='SAME'), + ] + + +def WideResnetGroup(n, channels, strides=(1, 1), bn_momentum=0.9, mode='train'): + shortcut = [ + tl.Conv(channels, (3, 3), strides, padding='SAME'), + ] + return [ + tl.Residual(WideResnetBlock(channels, strides, bn_momentum=bn_momentum, + mode=mode), + shortcut=shortcut), + tl.Residual([WideResnetBlock(channels, (1, 1), bn_momentum=bn_momentum, + mode=mode) + for _ in range(n - 1)]), + ] + + +def WideResnet(n_blocks=3, widen_factor=1, n_output_classes=10, bn_momentum=0.9, + mode='train'): + """WideResnet from https://arxiv.org/pdf/1605.07146.pdf. + + Args: + n_blocks: int, number of blocks in a group. total layers = 6n + 4. + widen_factor: int, widening factor of each group. k=1 is vanilla resnet. + n_output_classes: int, number of distinct output classes. + bn_momentum: float, momentum in BatchNorm. + mode: Whether we are training or evaluating or doing inference. + + Returns: + The list of layers comprising a WideResnet model with the given parameters. + """ + return tl.Model( + tl.ToFloat(), + tl.Conv(16, (3, 3), padding='SAME'), + WideResnetGroup(n_blocks, 16 * widen_factor, bn_momentum=bn_momentum, + mode=mode), + WideResnetGroup(n_blocks, 32 * widen_factor, (2, 2), + bn_momentum=bn_momentum, mode=mode), + WideResnetGroup(n_blocks, 64 * widen_factor, (2, 2), + bn_momentum=bn_momentum, mode=mode), + tl.BatchNorm(momentum=bn_momentum, mode=mode), + tl.Relu(), + tl.AvgPool(pool_size=(8, 8)), + tl.Flatten(), + tl.Dense(n_output_classes), + tl.LogSoftmax(), + ) diff --git a/trax/models/resnet_test.py b/trax/models/resnet_test.py new file mode 100644 index 000000000..b85626f89 --- /dev/null +++ b/trax/models/resnet_test.py @@ -0,0 +1,45 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Resnet models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from trax import backend +from trax import layers as tl +from trax.models import resnet + + +class ResnetTest(absltest.TestCase): + + def test_resnet(self): + input_shape = (3, 256, 256, 3) + model = resnet.Resnet50(d_hidden=8, n_output_classes=10) + final_shape = tl.check_shape_agreement(model, input_shape) + self.assertEqual((3, 10), final_shape) + + def test_wide_resnet(self): + input_shape = (3, 32, 32, 3) + model = resnet.WideResnet(n_blocks=1, n_output_classes=10) + final_shape = tl.check_shape_agreement(model, input_shape) + self.assertEqual((3, 10), final_shape) + + + +if __name__ == '__main__': + absltest.main() diff --git a/trax/models/transformer.py b/trax/models/transformer.py new file mode 100644 index 000000000..640dc3716 --- /dev/null +++ b/trax/models/transformer.py @@ -0,0 +1,395 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transformer Models.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +from trax import layers as tl + + +def FeedForward(d_model, d_ff, dropout, layer_idx, mode): + """Feed-forward block with layer normalization at start.""" + return [ + tl.LayerNorm(), + tl.Dense(d_ff), + tl.Relu(), + tl.Dropout(rate=dropout, name='ff_middle_%d' % layer_idx, mode=mode), + tl.Dense(d_model), + tl.Dropout(rate=dropout, name='ff_final_%d' % layer_idx, mode=mode), + ] + + +def EncoderBlock(d_model, d_ff, n_heads, dropout, layer_idx, mode): + """Returns a layer sequence that implements a Transformer encoder block. + + The input to the layer sequence is a pair, (activations, mask), where the + mask was created from the original source tokens to prevent attending to the + padding part of the input. + + Args: + d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + n_heads: int: number of attention heads + dropout: float: dropout rate (how much to drop out) + layer_idx: which layer are we at (for bookkeeping) + mode: str: 'train' or 'eval' + + Returns: + A sequence of layers that maps an (activations, mask) pair to an + (activations, mask) pair. + """ + attention = [ + tl.LayerNorm(), + tl.Attention(d_model, n_heads=n_heads, dropout=dropout, mode=mode), + tl.Dropout(rate=dropout, name='enc_attn_dropout', mode=mode), + ] + feed_forward = [ + FeedForward(d_model, d_ff, dropout, layer_idx=layer_idx, mode=mode), + ] + return [ + tl.Residual(attention), + tl.Residual(feed_forward), + ] + + +def TransformerEncoder(vocab_size, + n_classes=10, + d_model=512, + d_ff=2048, + n_layers=6, + n_heads=8, + dropout=0.1, + max_len=2048, + mode='train'): + """Returns a Transformer encoder model. + + The input to the model is a tensor of tokens. + + Args: + vocab_size: int: vocab size + n_classes: how many classes on output + d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + n_layers: int: number of encoder/decoder layers + n_heads: int: number of attention heads + dropout: float: dropout rate (how much to drop out) + max_len: int: maximum symbol length for positional encoding + mode: str: 'train' or 'eval' + + Returns: + A Transformer model as a layer that maps from a tensor of tokens to + activations over a set of output classes. + """ + embedder = [ + tl.Embedding(d_model, vocab_size), + tl.Dropout(rate=dropout, name='emb_dropout', mode=mode), + tl.PositionalEncoding(max_len=max_len), + ] + return tl.Model([ # tokens + tl.Dup(), # toks toks + tl.Parallel(embedder, tl.PaddingMask()), # vecs mask + [EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode) + for i in range(n_layers)], # vecs mask + tl.Parallel([], tl.Drop()), # ____ 0 + tl.LayerNorm(), # vecs + tl.Mean(axis=1), # Average on length. # vecs + tl.Dense(n_classes), # vecs + tl.LogSoftmax(), # vecs + ]) + + +def DecoderBlock(d_model, d_ff, n_heads, d_attention_key, d_attention_value, + attention_type, dropout, share_qk, layer_idx, mode): + """Returns a layer sequence that implements a Transformer decoder block. + + The input to the layer sequence is an activation tensor. + + Args: + d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + n_heads: int: number of attention heads + d_attention_key: int: depth of key vector for each attention head + d_attention_value: int: depth of value vector for each attention head + attention_type: subclass of tl.BaseCausalAttention: attention class to use + dropout: float: dropout rate (how much to drop out) + share_qk: bool, whether to share queries and keys + layer_idx: which layer are we at (for bookkeeping) + mode: str: 'train' or 'eval' + + Returns: + A sequence of layers that maps an activation tensor to an activation tensor. + """ + self_attention = [ + tl.LayerNorm(), # vec + tl.CausalAttention( + d_model, n_heads=n_heads, d_attention_key=d_attention_key, + d_attention_value=d_attention_value, attention_type=attention_type, + share_qk=share_qk, mode=mode), + tl.Dropout(rate=dropout, name='attention_%d' % layer_idx, mode=mode), + ] + feed_forward = [ + FeedForward(d_model, d_ff, dropout, layer_idx=layer_idx, mode=mode), + ] + return [ + tl.Residual(self_attention), + tl.Residual(feed_forward), + ] + + +def TransformerDecoder(vocab_size=None, + d_model=512, + d_ff=2048, + n_layers=6, + n_heads=8, + d_attention_key=None, + d_attention_value=None, + attention_type=tl.DotProductCausalAttention, + dropout=0.1, + share_qk=False, + max_len=2048, + mode='train'): + """Returns a Transformer decoder model. + + The input to the model is either continuous or discrete - controlled by + vocab_size. Does not shift the input to the right, i.e. the output for + timestep t is based on inputs up to timestep t inclusively. + + Args: + vocab_size: int or None: vocab size if running on discrete input, None + otherwise. + d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + n_layers: int: number of encoder/decoder layers + n_heads: int: number of attention heads + d_attention_key: int: depth of key vector for each attention head + (default is d_model // n_heads) + d_attention_value: int: depth of value vector for each attention head + (default is d_model // n_heads) + attention_type: subclass of tl.BaseCausalAttention: attention class to use + dropout: float: dropout rate (how much to drop out) + share_qk: bool, whether to share queries and keys in decoder attention + max_len: int: maximum symbol length for positional encoding + mode: str: 'train' or 'eval' + + Returns: + A Transformer decoder as a layer that maps from a continuous or discrete + tensor to a continuous tensor. + """ + if vocab_size is None: + input_layer = tl.Dense + else: + input_layer = functools.partial(tl.Embedding, vocab_size=vocab_size) + return tl.Model( # vecs + input_layer(d_model), # vecs + tl.Dropout(rate=dropout, mode=mode), + tl.PositionalEncoding(max_len=max_len), + [DecoderBlock( # pylint: disable=g-complex-comprehension + d_model, d_ff, n_heads, d_attention_key, d_attention_value, + attention_type, dropout, share_qk, i, mode) + for i in range(n_layers)], # vecs + tl.LayerNorm(), # vecs + ) + + +def TransformerLM(vocab_size, + d_model=512, + d_ff=2048, + n_layers=6, + n_heads=8, + d_attention_key=None, + d_attention_value=None, + attention_type=tl.DotProductCausalAttention, + dropout=0.1, + share_qk=False, + max_len=2048, + n_chunks=0, + mode='train'): + """Returns a Transformer language model. + + The input to the model is a tensor of tokens. (This model uses only the + decoder part of the overall Transformer.) + + Args: + vocab_size: int: vocab size + d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + n_layers: int: number of encoder/decoder layers + n_heads: int: number of attention heads + d_attention_key: int: depth of key vector for each attention head + (default is d_model // n_heads) + d_attention_value: int: depth of value vector for each attention head + (default is d_model // n_heads) + attention_type: subclass of tl.BaseCausalAttention: attention class to use + dropout: float: dropout rate (how much to drop out) + share_qk: bool, whether to share queries and keys in decoder attention + max_len: int: maximum symbol length for positional encoding + n_chunks: int: number of chunks (must match input pipeline) + mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference + + Returns: + A Transformer language model as a layer that maps from a tensor of tokens + to activations over a vocab set. + """ + if n_chunks == 0: + concatenate_chunks = split_chunks = [] + else: + concatenate_chunks = tl.Concatenate(n_items=n_chunks) + split_chunks = tl.Split(n_sections=n_chunks, axis=-2) + + embedder = [ + tl.Embedding(d_model, vocab_size), + tl.Dropout(rate=dropout, name='embedding', mode=mode), + tl.PositionalEncoding(max_len=max_len, mode=mode), + ] + + return tl.Model( # tokens (or chunked tuple of tokens) + concatenate_chunks, # tokens + tl.ShiftRight(mode=mode), # toks + embedder, # vecs + [DecoderBlock( # pylint: disable=g-complex-comprehension + d_model, d_ff, n_heads, d_attention_key, d_attention_value, + attention_type, dropout, share_qk, i, mode) + for i in range(n_layers)], # vecs + tl.LayerNorm(), # vecs + tl.Dense(vocab_size), # vecs + tl.LogSoftmax(), # vecs + split_chunks, # vecs (or chunked tuple of vecs) + ) + + +def EncoderDecoder(d_model, d_ff, n_heads, dropout, layer_idx, mode): + """Transformer encoder-decoder layer. + + The input is a triple (decoder_input, mask, encoder) where the mask is + created from the original source to prevent attending to the padding part + of the encoder. + + Args: + d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + n_heads: int: number of attention heads + dropout: float: dropout rate (how much to drop out) + layer_idx: which layer are we at (for bookkeeping) + mode: str: 'train' or 'eval' + + Returns: + the layer, returning a triple (decoder_activations, mask, encoder). + """ + decoder_self_attention = [ # vecs_d pmask vecs_e + tl.LayerNorm(), # vecs_d ..... ...... + tl.BasicCausalAttention( + d_model, n_heads=n_heads, dropout=dropout, mode=mode), + tl.Dropout(rate=dropout, mode=mode), # vecs_d ..... ...... + ] + decoder_to_encoder_attention = [ # vecs_d masks vecs_e + tl.LayerNorm(), # vecs_d masks vecs_e + tl.Parallel([], [], tl.Dup()), # ______ _____ vecs_e vecs_e + tl.Parallel([], tl.Swap()), # ______ vecs_e masks ...... + tl.Parallel([], tl.Dup()), # ______ vecs_e vecs_e ..... ...... + tl.AttentionQKV( # (q k v masks ... --> vecs_d masks ...) + d_model, n_heads=n_heads, dropout=dropout, mode=mode), + tl.Dropout(rate=dropout, mode=mode), # vecs_d mask vecs_e + ] + feed_forward = [ + FeedForward(d_model, d_ff, dropout, layer_idx=layer_idx, mode=mode), + ] + return [ # vecs_d masks vecs_e + tl.Residual(decoder_self_attention), # vecs_d masks vecs_e + tl.Residual(decoder_to_encoder_attention), # vecs_d masks vecs_e + tl.Residual(feed_forward), # vecs_d masks vecs_e + ] + + +def Transformer(input_vocab_size, + output_vocab_size=None, + d_model=512, + d_ff=2048, + n_layers=6, + n_heads=8, + dropout=0.1, + max_len=2048, + mode='train'): + """Returns a Transformer model. + + This model expects an input pair: target, source. + + Args: + input_vocab_size: int: vocab size of the source. + output_vocab_size: int (optional): vocab size of the target. If None, the + source and target are assumed to have the same vocab. + d_model: int: depth of embedding + d_ff: int: depth of feed-forward layer + n_layers: int: number of encoder/decoder layers + n_heads: int: number of attention heads + dropout: float: dropout rate (how much to drop out) + max_len: int: maximum symbol length for positional encoding + mode: str: 'train' or 'eval' + + Returns: + A Transformer model as a layer that maps from a target, source pair to + activations over a vocab set. + """ + in_embed = [ # tokens + tl.Embedding(d_model, input_vocab_size), # vecs + tl.Dropout(rate=dropout, mode=mode), # vecs + tl.PositionalEncoding(max_len=max_len), # vecs + ] + + if output_vocab_size is None: + output_vocab_size = input_vocab_size + out_embed = in_embed + else: + out_embed = [ # tokens + tl.Embedding(d_model, output_vocab_size), # vecs + tl.Dropout(rate=dropout, mode=mode), # vecs + tl.PositionalEncoding(max_len=max_len), # vecs + ] + + encoder_stack = ( # masks vectors --> masks vectors + [EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode) + for i in range(n_layers)]) + + encoder_decoder_stack = ( # vecs_d masks vecs_e --> vecs_d masks vecs_e + [EncoderDecoder(d_model, d_ff, n_heads, dropout, i, mode) + for i in range(n_layers)]) + + # Input: encoder_side_tokens, decoder_side_tokens + return tl.Model( # tokens_e tokens_d + tl.Parallel([], tl.Dup()), # toks_e toks_d toks_d (for loss) + tl.Swap(), # toks_d toks_e .... + + # Encode. + tl.Parallel( # toks_d toks_e + [], [tl.Dup(), # ______ toks_e toks_e + tl.Parallel(in_embed, tl.PaddingMask()), # ______ vecs_e masks + encoder_stack, # ______ vecs_e masks + tl.LayerNorm(), # ______ vecs_e ..... + tl.Swap()]), # ______ masks vecs_e + + # Decode. # toks_d masks vecs_e + tl.ShiftRight(), # toks_d ..... ...... + out_embed, # vecs_d ..... ...... + tl.Dup(), # vecs_d vecs_d ..... ...... + tl.Parallel([], tl.EncoderDecoderMask()), # ______ masks ...... + encoder_decoder_stack, # vecs_d masks vecs_e + tl.Parallel([], tl.Drop(), tl.Drop()), # vecs_d + tl.LayerNorm(), # vecs_d + tl.Dense(output_vocab_size), # vecs_d + tl.LogSoftmax(), # vecs_d + ) diff --git a/trax/models/transformer_test.py b/trax/models/transformer_test.py new file mode 100644 index 000000000..e2210d3d2 --- /dev/null +++ b/trax/models/transformer_test.py @@ -0,0 +1,104 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Transformer models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as onp +from trax import backend +from trax import layers as tl +from trax.backend import numpy as np +from trax.models import transformer + + +class TransformerTest(parameterized.TestCase): + + def test_transformer_lm_forward_shape(self): + """Run the Transformer LM forward and check output shape.""" + vocab_size = 16 + input_shape = [3, 5] + model = transformer.TransformerLM( + vocab_size, d_model=32, d_ff=64, n_layers=2, n_heads=2) + final_shape = tl.check_shape_agreement( + model, tuple(input_shape), integer_inputs=True) + self.assertEqual(tuple(input_shape + [vocab_size]), final_shape) + + def _test_transformer_forward_shape(self, input_vocab_size, + output_vocab_size): + """Run the Transformer forward and check output shape.""" + single_input_shape = [3, 5] + input_shape = (tuple(single_input_shape), tuple(single_input_shape)) + model = transformer.Transformer( + input_vocab_size, output_vocab_size, + d_model=32, d_ff=64, n_layers=2, n_heads=2) + final_shape = tl.check_shape_agreement( + model, input_shape, integer_inputs=True) + expected_shape = (tuple(single_input_shape + + [output_vocab_size if output_vocab_size is not None + else input_vocab_size])) + self.assertEqual(expected_shape, final_shape[0]) + + @parameterized.named_parameters( + ('same_vocab', 16, None), + ('same_size', 16, 16), + ('different_size', 16, 50)) + def test_transformer_forward_shape(self, input_vocab_size, output_vocab_size): + """Run the Transformer forward and check output shape.""" + self._test_transformer_forward_shape(input_vocab_size, output_vocab_size) + + + def _test_fast_inference(self, attention_type, length): + with backend.use_backend('jax'): + vocab_size = 16 + model_fn = functools.partial( + transformer.TransformerLM, + vocab_size=vocab_size, d_model=4, d_ff=8, n_layers=2, n_heads=2, + attention_type=attention_type, + ) + model_slow = model_fn(mode='eval') + model_fast = model_fn(mode='predict') + rng = backend.random.get_prng(0) + batch_size = 2 + # Given the same rng, both models initialize with the same parameters. + model_slow.initialize_once((batch_size, 1), np.int32, rng) + model_fast.initialize_once((batch_size, 1), np.int32, rng) + + buf = onp.zeros((batch_size, length), dtype=np.int32) + next_sym = onp.zeros((batch_size, 1), dtype=onp.int32) + + for index in range(length): + logits_slow = model_slow(buf, rng=rng) + logits_fast = model_fast(next_sym, rng=rng) + onp.testing.assert_array_almost_equal( + logits_slow[:, index, :], logits_fast[:, 0, :]) + next_sym = onp.random.randint(vocab_size, size=(batch_size, 1)) + buf[:, index] = next_sym[:, 0] + + def test_dot_product_causal_attention_fast_inference(self): + self._test_fast_inference(tl.DotProductCausalAttention, length=5) + + def test_time_bin_causal_attention_fast_inference(self): + attention = functools.partial(tl.TimeBinCausalAttention, bin_length=2) + self._test_fast_inference(attention, length=7) + +if __name__ == '__main__': + absltest.main() diff --git a/trax/notebooks/trax_demo_iclr2019.ipynb b/trax/notebooks/trax_demo_iclr2019.ipynb new file mode 100644 index 000000000..4cf5e5788 --- /dev/null +++ b/trax/notebooks/trax_demo_iclr2019.ipynb @@ -0,0 +1,854 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Trax Demo", + "version": "0.3.2", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "ySEmBgmqMSIJ", + "colab_type": "text" + }, + "source": [ + "##### Copyright 2019 Google LLC.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "you may not use this file except in compliance with the License.\n", + "You may obtain a copy of the License at\n", + "\n", + "https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "Unless required by applicable law or agreed to in writing, software\n", + "distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "See the License for the specific language governing permissions and\n", + "limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "o4WGihMLneYq", + "colab_type": "text" + }, + "source": [ + "# Trax: Train Models in JAX\n", + "\n", + "[JAX](https://github.com/google/jax) allows you to write [numpy](https://www.numpy.org/) and run it fast on accelerators.\n", + "\n", + "This makes ML research more *fun* and *clear* so we made\n", + "* [Trax](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/trax): a library of models in JAX.\n", + "\n", + "In this demo we show how to:\n", + "* Train a Trax model on a toy copy problem.\n", + "* Decode from a pre-trained [Transformer](https://arxiv.org/abs/1706.03762) language model.\n", + "* Define [Transformer](https://arxiv.org/abs/1706.03762) from scratch in Trax.\n", + "* Do research in Trax: play with hard attention to see how it impacts training and results.\n", + "\n", + "We would like your feedback!\n", + "* What are the parts you like or dislike in JAX and Trax?\n", + "* Will you start doing your research in Trax? If not, why? What would change your mind?\n", + "* What should we focus on? Speed, cleanliness, memory use?\n", + "* If you cannot tell us in person, please add your feedback on [this github issue](https://github.com/tensorflow/tensor2tensor/issues/1478).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8YQw0hySTVlK", + "colab_type": "text" + }, + "source": [ + "## Installs\n", + "\n", + "We install jax and trax and download a pretrained model and vocab file." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "vAWJVzYRnbDU", + "colab_type": "code", + "outputId": "6cdeff6f-3fc9-406f-feaf-fd1f8d9de775", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 578 + } + }, + "source": [ + "# Install JAX for GPU and Tensor2Tensor.\n", + "!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda100/jaxlib-0.1.14-cp36-none-linux_x86_64.whl\n", + "!pip install --upgrade -q jax==0.1.27\n", + "!pip install --upgrade -q tensor2tensor==1.13.4\n", + "# Grab language-model checkpoint and vocab file.\n", + "!rm -f model.pkl\n", + "!wget https://storage.googleapis.com/traxdemo/model.pkl\n", + "!wget https://storage.googleapis.com/traxdemo/vocab.lm1b.en.32768\n", + "# Show GPU type.\n", + "!nvidia-smi -L" + ], + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "text": [ + "\u001b[K |████████████████████████████████| 44.6MB 1.2MB/s \n", + "\u001b[K |████████████████████████████████| 174kB 3.5MB/s \n", + "\u001b[K |████████████████████████████████| 61kB 24.4MB/s \n", + "\u001b[?25h Building wheel for jax (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Building wheel for opt-einsum (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[K |████████████████████████████████| 1.4MB 3.4MB/s \n", + "\u001b[K |████████████████████████████████| 686kB 45.8MB/s \n", + "\u001b[K |████████████████████████████████| 143kB 40.2MB/s \n", + "\u001b[K |████████████████████████████████| 296kB 32.6MB/s \n", + "\u001b[?25h Building wheel for pypng (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "--2019-05-14 22:57:21-- https://storage.googleapis.com/traxdemo/model.pkl\n", + "Resolving storage.googleapis.com (storage.googleapis.com)... 209.85.234.128, 2607:f8b0:4001:c12::80\n", + "Connecting to storage.googleapis.com (storage.googleapis.com)|209.85.234.128|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 211170062 (201M) [application/octet-stream]\n", + "Saving to: ‘model.pkl’\n", + "\n", + "model.pkl 100%[===================>] 201.39M 101MB/s in 2.0s \n", + "\n", + "2019-05-14 22:57:23 (101 MB/s) - ‘model.pkl’ saved [211170062/211170062]\n", + "\n", + "--2019-05-14 22:57:23-- https://storage.googleapis.com/traxdemo/vocab.lm1b.en.32768\n", + "Resolving storage.googleapis.com (storage.googleapis.com)... 64.233.183.128, 2607:f8b0:4001:c07::80\n", + "Connecting to storage.googleapis.com (storage.googleapis.com)|64.233.183.128|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 297760 (291K) [application/octet-stream]\n", + "Saving to: ‘vocab.lm1b.en.32768’\n", + "\n", + "vocab.lm1b.en.32768 100%[===================>] 290.78K --.-KB/s in 0.007s \n", + "\n", + "2019-05-14 22:57:24 (40.8 MB/s) - ‘vocab.lm1b.en.32768’ saved [297760/297760]\n", + "\n", + "GPU 0: Tesla T4 (UUID: GPU-1959cc75-52ab-cf03-e5fa-36aee0d59bc5)\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vvFrqacVS6B6", + "colab_type": "text" + }, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "dYq8J8uBn9ZC", + "colab_type": "code", + "outputId": "db8ca8de-164c-4355-8abb-a493e7f9f393", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 136 + } + }, + "source": [ + "from six.moves import cPickle\n", + "import os\n", + "import datetime\n", + "import random\n", + "\n", + "import numpy as onp\n", + "from matplotlib import pyplot as plt\n", + "\n", + "from jax.ops import index, index_update\n", + "\n", + "from trax import trax\n", + "from trax import layers as tl\n", + "from trax import inputs as trax_input\n", + "from trax import models as trax_models\n", + "from trax import optimizers as trax_optimizers\n", + "from trax import backend\n", + "from trax.backend import numpy as np\n", + "from trax.backend import random as trax_random" + ], + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "text": [ + "\n", + "WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.\n", + "For more information, please see:\n", + " * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n", + " * https://github.com/tensorflow/addons\n", + "If you depend on functionality not listed there, please file an issue.\n", + "\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zR6RVHx4lPzA", + "colab_type": "text" + }, + "source": [ + "# Toy Copy Problem\n", + "\n", + "Here we define batched random integer inputs for a trivial sequence-copy learning task." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "wGmWmpIslQYv", + "colab_type": "code", + "colab": {} + }, + "source": [ + "VOCAB_SIZE = 128\n", + "def toy_problem_inputs(num_devices, batch_size=64,\n", + " train_lengths=[10, 20], eval_lengths=[20]):\n", + " \"\"\"Make Inputs for the toy problem of the language 0w0w for w in [1..127]*.\n", + "\n", + " Args:\n", + " num_devices: how many devices to build the inputs for (assert 1 for colab).\n", + " batch_size: how large are the batches.\n", + " train_lengths: lengths of w for training.\n", + " eval_lengths: lengths of w for eval.\n", + "\n", + " Returns:\n", + " trax.inputs.Inputs\n", + " \"\"\"\n", + " assert num_devices == 1\n", + " def random_minibatches(length_list):\n", + " \"\"\"Generate a stream of random mini-batches.\"\"\"\n", + " while True:\n", + " length = random.choice(length_list)\n", + " w = onp.random.randint(low=1, high=VOCAB_SIZE-1,\n", + " size=(batch_size, length // 2))\n", + " zero = onp.zeros([batch_size, 1], onp.int32)\n", + " x = onp.concatenate([zero, w, zero, w], axis=1)\n", + " yield (x, x) # In a language model input and output are the same.\n", + "\n", + " return trax_input.Inputs(\n", + " train_stream=lambda: random_minibatches(train_lengths),\n", + " train_eval_stream=lambda: random_minibatches(train_lengths),\n", + " eval_stream=lambda: random_minibatches(eval_lengths),\n", + " input_shape=(None,))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "eU0mpaf1lRky", + "colab_type": "code", + "outputId": "bf94086c-5d97-462b-b565-d4ba5f59b6c4", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + } + }, + "source": [ + "inputs = toy_problem_inputs(1)\n", + "print(next(inputs.train_stream())[0][0])" + ], + "execution_count": 5, + "outputs": [ + { + "output_type": "stream", + "text": [ + "[ 0 68 91 99 107 115 113 111 17 102 48 0 68 91 99 107 115 113\n", + " 111 17 102 48]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KvNaSWu5g2Vm", + "colab_type": "text" + }, + "source": [ + "## Baseline Transformer on Toy Problem" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "AGDmtrgcl73M", + "colab_type": "code", + "outputId": "4c0f12e9-10ec-4e67-9f15-d2cc7084c083", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 748 + } + }, + "source": [ + "timestamp = datetime.datetime.now().strftime(\"%Y%m%d_%H%M\")\n", + "output_dir = os.path.expanduser(\"~/trax_lm_%s\" % timestamp)\n", + "def model(mode):\n", + " return trax_models.TransformerLM(\n", + " VOCAB_SIZE, feature_depth=128,\n", + " feedforward_depth=256, num_layers=3,\n", + " num_heads=4, mode=mode)\n", + "_ = trax.train(model=model,\n", + " inputs=toy_problem_inputs,\n", + " output_dir=output_dir,\n", + " train_steps=3000,\n", + " eval_steps=10,\n", + " eval_frequency=1000)" + ], + "execution_count": 5, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Step 0: Starting training using 1 devices\n", + "\n", + "Step 1: Ran 1 train steps in 36.77 secs\n", + "Step 1: Total trainable parameters size: 692736\n", + "Step 1: Evaluation\n", + "Step 1: train accuracy | 0.00616714\n", + "Step 1: train neg_log_perplexity | -5.06836748\n", + "Step 1: train loss | 5.06836748\n", + "Step 1: eval accuracy | 0.00610795\n", + "Step 1: eval neg_log_perplexity | -5.20451212\n", + "Step 1: eval loss | 5.20451212\n", + "Step 1: Finished evaluation\n", + "\n", + "Step 1000: Ran 999 train steps in 89.13 secs\n", + "Step 1000: Evaluation\n", + "Step 1000: train accuracy | 0.45719695\n", + "Step 1000: train neg_log_perplexity | -2.71764731\n", + "Step 1000: train loss | 2.71764731\n", + "Step 1000: eval accuracy | 0.41278410\n", + "Step 1000: eval neg_log_perplexity | -2.94052887\n", + "Step 1000: eval loss | 2.94052887\n", + "Step 1000: Finished evaluation\n", + "\n", + "Step 2000: Ran 1000 train steps in 15.61 secs\n", + "Step 2000: Evaluation\n", + "Step 2000: train accuracy | 0.43169984\n", + "Step 2000: train neg_log_perplexity | -2.82782769\n", + "Step 2000: train loss | 2.82782769\n", + "Step 2000: eval accuracy | 0.41278410\n", + "Step 2000: eval neg_log_perplexity | -2.92255998\n", + "Step 2000: eval loss | 2.92255998\n", + "Step 2000: Finished evaluation\n", + "\n", + "Step 3000: Ran 1000 train steps in 15.64 secs\n", + "Step 3000: Evaluation\n", + "Step 3000: train accuracy | 0.45053267\n", + "Step 3000: train neg_log_perplexity | -2.73254609\n", + "Step 3000: train loss | 2.73254609\n", + "Step 3000: eval accuracy | 0.41249999\n", + "Step 3000: eval neg_log_perplexity | -2.92720962\n", + "Step 3000: eval loss | 2.92720962\n", + "Step 3000: Finished evaluation\n", + "Step 3000: Training done\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eapBBkRUuho7", + "colab_type": "text" + }, + "source": [ + "# Decoding from a Pre-Trained Transformer Language Model" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "H6hVQ3v5iC00", + "colab_type": "code", + "outputId": "812949cc-4294-4a42-f55a-c40f65e151f8", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 187 + } + }, + "source": [ + "# load model checkpoint\n", + "with open(\"model.pkl\", \"rb\") as f:\n", + " (params, step, history) = cPickle.load(f, encoding=\"latin1\")\n", + "\n", + "# lm1b subword vocab\n", + "def clean(x):\n", + " return x[1:-2]\n", + "with open(\"vocab.lm1b.en.32768\", \"r\") as fp:\n", + " vocab = list(map(clean, fp.readlines()))\n", + "vocab_map = {v:idx for idx,v in enumerate(vocab)}\n", + "\n", + "list(enumerate(vocab))[:10]" + ], + "execution_count": 6, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[(0, '_'),\n", + " (1, '_'),\n", + " (2, 'the_'),\n", + " (3, ' , _'),\n", + " (4, ' ._'),\n", + " (5, 'to_'),\n", + " (6, 'of_'),\n", + " (7, 'a_'),\n", + " (8, 'and_'),\n", + " (9, 'in_')]" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 6 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "W-7s9RXQNIru", + "colab_type": "code", + "colab": {} + }, + "source": [ + "tlm = trax_models.TransformerLM(\n", + " dropout=0.1, \n", + " feature_depth=512, \n", + " feedforward_depth=2048, \n", + " max_len=2048, \n", + " mode='eval', \n", + " num_heads=8, \n", + " num_layers=6, \n", + " vocab_size=32000)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "iLdtplDpdTMr", + "colab_type": "code", + "colab": {} + }, + "source": [ + "def gumbel_sample(v, temperature=0.8):\n", + " u = onp.random.uniform(low=1e-9, high=1.0, size=v.shape)\n", + " g = -onp.log(-onp.log(u))\n", + " return np.argmax(v + g * temperature)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "IHSbtHzPjW6i", + "colab_type": "code", + "outputId": "7a8306b7-6c6b-41ba-c5aa-c1a76d9d8037", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 102 + } + }, + "source": [ + "prompt = \"Please_\"\n", + "num_samples = 5\n", + "max_length = 20\n", + "for _ in range(num_samples):\n", + " enc = [vocab_map[w] for w in str.split(prompt)]\n", + " pos = len(enc)\n", + " rng = trax_random.get_prng(0)\n", + " data = np.zeros((1, 50), dtype=np.int32)\n", + " data = index_update(data, index[0, 0:pos], enc)\n", + "\n", + " while pos < max_length:\n", + " tmp = tlm(data, params=params, rng=rng)\n", + " next_sym = gumbel_sample(tmp[0, pos])\n", + " data = index_update(data, index[0, pos], next_sym)\n", + " pos += 1\n", + " if int(next_sym) == 1:\n", + " break\n", + "\n", + " print(\"\".join([vocab[idx] for idx in onp.array(data)[0, 0:pos]]))" + ], + "execution_count": 10, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Please_write_to_him_to_tell_him_about_the_Wallace_and_Gromit_films_. _and_to_give_him_this_\n", + "Please_do_not_turn_to_making_sure_your_children_are_already_in_school_or_that_you_have_school_ .__\n", + "Please_read_the_full_prospectus_to_see_if_the_proposed_transaction_may_be_accurate_ .__\n", + "Please_note_that_the_new_policy_has_been_strengthened_by_the_fact_that_Britney_Spears_ ' _mother_ , _Janet_Jackson_\n", + "Please_ , _please_aim_at_your_brother_ , _if_you_want_to_ .__\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ym8otS7HpUIO", + "colab_type": "text" + }, + "source": [ + "# Transformer from Scratch\n", + "\n", + "Here we re-implement multiheaded self-attention and a transformer language model from scratch using only a few simple linear primitives from trax.\n", + "\n", + "Note in particular the commented modifications in the core __DotProductAttention__ function as an example of how easy it is to modify layers and models for research using Trax." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "uw-GIdm2p_4X", + "colab_type": "code", + "colab": {} + }, + "source": [ + "def DotProductAttention(query, key, value, mask, dropout, mode, rng, hard_k=4):\n", + " \"\"\"Core dot product self-attention.\n", + " Args:\n", + " query: array of representations\n", + " key: array of representations\n", + " value: array of representations\n", + " mask: attention-mask, gates attention\n", + " dropout: float: dropout rate\n", + " mode: 'eval' or 'train': whether to use dropout\n", + " rng: JAX PRNGKey: subkey for disposable use\n", + " Returns:\n", + " Self attention for q, k, v arrays.\n", + " \"\"\"\n", + " depth = np.shape(query)[-1]\n", + " dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)\n", + " if mask is not None:\n", + " dots = np.where(mask, dots, -1e9)\n", + " # Softmax.\n", + " dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))\n", + " # ----------------------------------------------------------------------\n", + " # As an example of a simple research modification, we modify the typical \n", + " # dot-product attention mechanism with top-k \"hard attention\":\n", + " # ----------------------------------------------------------------------\n", + " if hard_k > 0:\n", + " top_k = np.sort(dots)[..., -hard_k] # Get the top-kth weight.\n", + " dots -= top_k[..., np.newaxis] # Subtract (be 0 for lower ones).\n", + " dots = np.maximum(dots, 0)\n", + " dots /= np.sum(dots, axis=-1, keepdims=True) # Re-normalize.\n", + " # ----------------------------------------------------------------------\n", + " if dropout >= 1.0:\n", + " raise ValueError('Dropout rates must be lower than 1.')\n", + " if dropout is not None and dropout > 0.0 and mode == 'train':\n", + " keep = backend.random.bernoulli(rng, 1.0 - dropout, dots.shape)\n", + " dots = np.where(keep, dots / (1.0 - dropout), 0)\n", + " out = np.matmul(dots, value)\n", + " # Uncomment to see an example Trax stack trace to this point:\n", + " # ----------------------------------------------------------------------\n", + " # raise ValueError(\"err\")\n", + " # ----------------------------------------------------------------------\n", + " return out\n", + "\n", + "\n", + "def _multihead_attention_output_shape( # pylint: disable=invalid-name\n", + " input_shapes, **unused_kwargs):\n", + " \"\"\"Helper: calculate multihead attention output shape.\"\"\"\n", + " q_shape = input_shapes[0][0] # Inputs are ((q, k, v), mask).\n", + " mask_shape = input_shapes[1]\n", + " return q_shape, mask_shape\n", + "\n", + "\n", + "@tl.layer(output_shape=_multihead_attention_output_shape)\n", + "def PureMultiHeadedAttention(x, params, num_heads=8, dropout=0.0,\n", + " mode='train', **kwargs):\n", + " \"\"\"Pure transformer-style multi-headed attention.\n", + " Args:\n", + " x: inputs ((q, k, v), mask)\n", + " params: parameters (none)\n", + " num_heads: int: number of attention heads\n", + " dropout: float: dropout rate\n", + " mode: str: 'train' or 'eval'\n", + " **kwargs: other arguments including the rng\n", + " Returns:\n", + " Pure Multi-headed attention result, and the mask.\n", + " \"\"\"\n", + " del params\n", + " rng = kwargs.get('rng', None)\n", + " (q, k, v), mask = x\n", + " feature_depth = q.shape[-1]\n", + " assert feature_depth % num_heads == 0\n", + " head_depth = feature_depth // num_heads\n", + " nbatch = np.shape(q)[0]\n", + " # nbatch, seqlen, feature_depth --> nbatch, num_heads, seqlen, head_depth\n", + " def SplitHeads(x):\n", + " return np.transpose(\n", + " np.reshape(x, (nbatch, -1, num_heads, head_depth)), (0, 2, 1, 3))\n", + " # nbatch, num_heads, seqlen, head_depth --> nbatch, seqlen, feature_depth\n", + " def JoinHeads(x): # pylint: disable=invalid-name\n", + " return np.reshape(\n", + " np.transpose(x, (0, 2, 1, 3)), (nbatch, -1, num_heads*head_depth))\n", + " # Split heads, dot-product attention, rejoin heads.\n", + " res = JoinHeads(\n", + " DotProductAttention(\n", + " SplitHeads(q), SplitHeads(k), SplitHeads(v), mask,\n", + " dropout=dropout, mode=mode, rng=rng))\n", + " return res, mask # Keep the mask.\n", + "\n", + "\n", + "def MultiHeadedAttentionQKV(\n", + " feature_depth, num_heads=8, dropout=0.0, mode='train'):\n", + " \"\"\"Transformer-style multi-headed attention.\n", + " Accepts inputs of the form (q, k, v), mask.\n", + " Args:\n", + " feature_depth: int: depth of embedding\n", + " num_heads: int: number of attention heads\n", + " dropout: float: dropout rate\n", + " mode: str: 'train' or 'eval'\n", + " Returns:\n", + " Multi-headed self-attention result and the mask.\n", + " \"\"\"\n", + " return tl.Serial(\n", + " tl.Parallel(\n", + " tl.Parallel(\n", + " tl.Dense(feature_depth),\n", + " tl.Dense(feature_depth),\n", + " tl.Dense(feature_depth),\n", + " ),\n", + " tl.Copy()\n", + " ),\n", + " PureMultiHeadedAttention( # pylint: disable=no-value-for-parameter\n", + " feature_depth=feature_depth, num_heads=num_heads,\n", + " dropout=dropout, mode=mode),\n", + " tl.Parallel(tl.Dense(feature_depth), tl.Copy())\n", + " )\n", + "\n", + "\n", + "def MultiHeadedAttention(\n", + " feature_depth, num_heads=8, dropout=0.0, mode='train'):\n", + " \"\"\"Transformer-style multi-headed attention.\n", + " Accepts inputs of the form (x, mask) and constructs (q, k, v) from x.\n", + " Args:\n", + " feature_depth: int: depth of embedding\n", + " num_heads: int: number of attention heads\n", + " dropout: float: dropout rate\n", + " mode: str: 'train' or 'eval'\n", + " Returns:\n", + " Multi-headed self-attention layer.\n", + " \"\"\"\n", + " return tl.Serial(\n", + " tl.Parallel(\n", + " # q = k = v = first input\n", + " tl.Branch(\n", + " tl.Copy(), tl.Copy(), tl.Copy()),\n", + " tl.Copy() # pass the mask\n", + " ),\n", + " MultiHeadedAttentionQKV( # pylint: disable=no-value-for-parameter\n", + " feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),\n", + " )" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "Ge42t7VZl-d2", + "colab_type": "code", + "colab": {} + }, + "source": [ + "def ResidualFeedForward(feature_depth,\n", + " feedforward_depth,\n", + " dropout,\n", + " mode):\n", + " \"\"\"Residual feed-forward layer with normalization at start.\"\"\"\n", + " return tl.Residual(\n", + " tl.LayerNorm(),\n", + " tl.Dense(feedforward_depth),\n", + " tl.Relu(),\n", + " tl.Dropout(rate=dropout, mode=mode),\n", + " tl.Dense(feature_depth),\n", + " tl.Dropout(rate=dropout, mode=mode)\n", + " )\n", + "\n", + "\n", + "def DecoderLayer(feature_depth,\n", + " feedforward_depth,\n", + " num_heads,\n", + " dropout,\n", + " mode):\n", + " \"\"\"Transformer decoder layer.\n", + " Args:\n", + " feature_depth: int: depth of embedding\n", + " feedforward_depth: int: depth of feed-forward layer\n", + " num_heads: int: number of attention heads\n", + " dropout: float: dropout rate (how much to drop out)\n", + " mode: str: 'train' or 'eval'\n", + " Returns:\n", + " the layer.\n", + " \"\"\"\n", + " return tl.Serial(\n", + " tl.Residual( # Self-attention block.\n", + " tl.LayerNorm(),\n", + " tl.Branch(tl.Copy(), tl.CausalMask(axis=-2)), # Create mask.\n", + " # We replace the \"stock\" self-attention layer with the one defined\n", + " # above:\n", + " # tl.MultiHeadedAttention(feature_depth, num_heads=num_heads,\n", + " # dropout=dropout, mode=mode),\n", + " MultiHeadedAttention(feature_depth, num_heads=num_heads,\n", + " dropout=dropout, mode=mode),\n", + " tl.Select(0), # Drop the mask.\n", + " tl.Dropout(rate=dropout, mode=mode)\n", + " ),\n", + " ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode=mode)\n", + " )\n", + "\n", + "\n", + "def TransformerLM(vocab_size,\n", + " feature_depth=512,\n", + " feedforward_depth=2048,\n", + " num_layers=6,\n", + " num_heads=8,\n", + " dropout=0.1,\n", + " max_len=2048,\n", + " mode='train'):\n", + " \"\"\"Transformer language model (only uses the decoder part of Transformer).\n", + " Args:\n", + " vocab_size: int: vocab size\n", + " feature_depth: int: depth of embedding\n", + " feedforward_depth: int: depth of feed-forward layer\n", + " num_layers: int: number of encoder/decoder layers\n", + " num_heads: int: number of attention heads\n", + " dropout: float: dropout rate (how much to drop out)\n", + " max_len: int: maximum symbol length for positional encoding\n", + " mode: str: 'train' or 'eval'\n", + " Returns:\n", + " the layer.\n", + " \"\"\"\n", + " return tl.Serial(\n", + " tl.ShiftRight(),\n", + " tl.Embedding(feature_depth, vocab_size),\n", + " tl.Dropout(rate=dropout, mode=mode),\n", + " tl.PositionalEncoding(max_len=max_len),\n", + " tl.Serial(*[DecoderLayer(feature_depth, feedforward_depth, num_heads,\n", + " dropout, mode)\n", + " for _ in range(num_layers)]),\n", + " tl.LayerNorm(),\n", + " tl.Dense(vocab_size),\n", + " tl.LogSoftmax()\n", + " )" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "WZxnwjAEqYDh", + "colab_type": "code", + "outputId": "f90e965d-2625-4e56-9038-65c087639051", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 748 + } + }, + "source": [ + "timestamp = datetime.datetime.now().strftime(\"%Y%m%d_%H%M\")\n", + "output_dir = os.path.expanduser(\"~/trax_lm_%s\" % timestamp)\n", + "def new_model(mode):\n", + " return TransformerLM(\n", + " VOCAB_SIZE, feature_depth=128,\n", + " feedforward_depth=256, num_layers=3,\n", + " num_heads=4, mode=mode)\n", + "_ = trax.train(model=new_model,\n", + " inputs=toy_problem_inputs,\n", + " output_dir=output_dir,\n", + " train_steps=3000,\n", + " eval_steps=10,\n", + " eval_frequency=1000)" + ], + "execution_count": 22, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Step 0: Starting training using 1 devices\n", + "\n", + "Step 1: Ran 1 train steps in 42.29 secs\n", + "Step 1: Total trainable parameters size: 692736\n", + "Step 1: Evaluation\n", + "Step 1: train accuracy | 0.00686553\n", + "Step 1: train neg_log_perplexity | -5.42891455\n", + "Step 1: train loss | 5.42891455\n", + "Step 1: eval accuracy | 0.00809659\n", + "Step 1: eval neg_log_perplexity | -5.39403439\n", + "Step 1: eval loss | 5.39403439\n", + "Step 1: Finished evaluation\n", + "\n", + "Step 1000: Ran 999 train steps in 109.64 secs\n", + "Step 1000: Evaluation\n", + "Step 1000: train accuracy | 0.12875238\n", + "Step 1000: train neg_log_perplexity | -4.29979420\n", + "Step 1000: train loss | 4.29979420\n", + "Step 1000: eval accuracy | 0.09928977\n", + "Step 1000: eval neg_log_perplexity | -4.45948172\n", + "Step 1000: eval loss | 4.45948172\n", + "Step 1000: Finished evaluation\n", + "\n", + "Step 2000: Ran 1000 train steps in 16.89 secs\n", + "Step 2000: Evaluation\n", + "Step 2000: train accuracy | 0.53104877\n", + "Step 2000: train neg_log_perplexity | -2.33383632\n", + "Step 2000: train loss | 2.33383632\n", + "Step 2000: eval accuracy | 0.54900569\n", + "Step 2000: eval neg_log_perplexity | -2.24813342\n", + "Step 2000: eval loss | 2.24813342\n", + "Step 2000: Finished evaluation\n", + "\n", + "Step 3000: Ran 1000 train steps in 16.91 secs\n", + "Step 3000: Evaluation\n", + "Step 3000: train accuracy | 0.56715208\n", + "Step 3000: train neg_log_perplexity | -2.15219927\n", + "Step 3000: train loss | 2.15219927\n", + "Step 3000: eval accuracy | 0.54928976\n", + "Step 3000: eval neg_log_perplexity | -2.25436211\n", + "Step 3000: eval loss | 2.25436211\n", + "Step 3000: Finished evaluation\n", + "Step 3000: Training done\n" + ], + "name": "stdout" + } + ] + } + ] +} diff --git a/trax/optimizers/__init__.py b/trax/optimizers/__init__.py new file mode 100644 index 000000000..44a400d4c --- /dev/null +++ b/trax/optimizers/__init__.py @@ -0,0 +1,37 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Optimizers defined in trax.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gin + +from trax.optimizers import base + + +def opt_configure(*args, **kwargs): + kwargs["module"] = "trax.optimizers" + return gin.external_configurable(*args, **kwargs) + +# Optimizers (using upper-case names). +# pylint: disable=invalid-name +SGD = opt_configure(base.SGD) +Momentum = opt_configure(base.Momentum) +RMSProp = opt_configure(base.RMSProp) +Adam = opt_configure(base.Adam) +Adafactor = opt_configure(base.Adafactor) +SM3 = opt_configure(base.SM3) diff --git a/trax/optimizers/base.py b/trax/optimizers/base.py new file mode 100644 index 000000000..060087ed0 --- /dev/null +++ b/trax/optimizers/base.py @@ -0,0 +1,465 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Trax base optimizer class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from trax.backend import numpy as np +from trax.layers import base as layers + + +def tree_flatten(tree): + """Flatten a tree into a list.""" + if isinstance(tree, (list, tuple)): + # In python, sum of lists starting from [] is the concatenation. + return sum([tree_flatten(t) for t in tree], []) + if isinstance(tree, dict): + # Only use the values in case of a dictionary node. + return sum([tree_flatten(v) for v in tree.values()], []) + return [tree] + + +def tree_unflatten(flat, tree): + """Unflatten a list into a tree given the tree shape as second argument. + + Args: + flat: a flat list of elements to be assembled into a tree. + tree: a tree with the structure we want to have in the new tree. + + Returns: + A pair (new_tree, rest_of_flat) where the new tree that has the structure + of tree but with leaves from flat, and the remaining elements of flat if + more were provided than the number of leaves of tree (useful for recursion). + """ + if isinstance(tree, (list, tuple)): + new_tree, rest = [], flat + for t in tree: + new_t, rest = tree_unflatten(rest, t) + new_tree.append(new_t) + new_tree = tuple(new_tree) if isinstance(tree, tuple) else new_tree + return new_tree, rest + if isinstance(tree, dict): + new_tree, rest = {}, flat + for k in tree: + new_v, rest = tree_unflatten(rest, tree[k]) + new_tree[k] = new_v + return new_tree, rest + return flat[0], flat[1:] + + +class Optimizer(object): + """Optimizer object, base class. Maps per-parameter functions to trees.""" + + def __init__(self, learning_rate, **init_opt_params): + """Initialize the optimizer. + + Takes the initial optimizer parameters as positional arguments. They are fed + back to the optimizer in tree_update, in the same order. They can be changed + between updates, e.g. for learning rate schedules. + + The constructor should be overridden in derived classes to give names to the + optimizer parameters, so the gin configuration can set them. + + Args: + learning_rate: The initial learning rate. + **init_opt_params: Initial values of any additional optimizer parameters. + """ + init_opt_params["learning_rate"] = learning_rate + self._init_opt_params = { + name: np.array(value) for (name, value) in init_opt_params.items() + } + + def init(self, params): + """Create optimizer slots for the given parameters.""" + raise NotImplementedError + + def update(self, step, grads, params, slots, opt_params): + """Update a single parameter array. + + Args: + step: Current step. + grads: Gradients. + params: Parameters. + slots: Optimizer slots (e.g. gradient moments). + opt_params: Optimizer (hyper)parameters (e.g. learning rate, momentum). + + Returns: + (new_params, new_slots) + """ + raise NotImplementedError + + # End subclass interface. + + def tree_init(self, param_tree): + return ( + [self.init(param) for param in tree_flatten(param_tree)], + self._init_opt_params, + ) + + def _update_and_check(self, step, grads, params, slots, opt_params): + """Update a single parameter array and check types.""" + new_params, new_slots = self.update( + step, grads, params, slots, opt_params) + if isinstance(params, np.ndarray): + assert isinstance(new_params, np.ndarray), ( + "The type of the new parameter values should be np.ndarray; got %s" % + type(new_params)) + assert new_params.dtype == params.dtype, ( + "The dtype of the new parameter values (%s) is not the same as the " + "old one (%s)" % (new_params.dtype, params.dtype)) + return new_params, new_slots + + def tree_update(self, step, grad_tree, param_tree, slots, opt_params): + grads_flat = tree_flatten(grad_tree) + params_flat = tree_flatten(param_tree) + updated_pairs = [ + self._update_and_check(step, grad, param, slot, opt_params) + for (grad, param, slot) in zip(grads_flat, params_flat, slots) + ] + new_params_flat, new_slots = zip(*updated_pairs) + new_params, _ = tree_unflatten(new_params_flat, param_tree) + return new_params, new_slots + + +# Utilities. + + +def l2_norm(tree): + """Compute the l2 norm of a pytree of arrays. Useful for weight decay.""" + leaves = tree_flatten(tree) + return np.sqrt(sum(np.vdot(x, x) for x in leaves)) + + +def clip_grads(grad_tree, max_norm): + """Clip gradients stored as a pytree of arrays to maximum norm `max_norm`.""" + norm = l2_norm(grad_tree) + normalize = lambda g: np.where(norm < max_norm, g, g * (max_norm / norm)) + return layers.nested_map(normalize, grad_tree) + + +# Optimizers. + + +class SGD(Optimizer): + """Plain SGD optimizer.""" + + def init(self, params): + return None + + def update(self, step, grads, params, slots, opt_params): + del step + del slots + learning_rate = opt_params["learning_rate"] + return params - (learning_rate * grads).astype(params.dtype), None + + +class Momentum(Optimizer): + """Nesterov momentum optimizer.""" + + def __init__(self, learning_rate, mass=0.9, weight_decay_rate=1e-5): # pylint: disable=useless-super-delegation + super(Momentum, self).__init__( + learning_rate=learning_rate, + mass=mass, + weight_decay_rate=weight_decay_rate, + ) + + def init(self, params): + return np.zeros_like(params) + + def update(self, step, grads, params, velocity, opt_params): + del step + learning_rate = opt_params["learning_rate"] + mass = opt_params["mass"] + weight_decay_rate = opt_params["weight_decay_rate"] + new_velocity = mass * velocity + grads + new_params = (1 - weight_decay_rate) * params - ( + learning_rate * (mass * new_velocity + grads)).astype(params.dtype) + return (new_params, new_velocity) + + +class RMSProp(Optimizer): + """RMSProp optimizer.""" + + def __init__(self, learning_rate, gamma=0.9, eps=1e-8): # pylint: disable=useless-super-delegation + super(RMSProp, self).__init__( + learning_rate=learning_rate, + gamma=gamma, + eps=eps, + ) + + def init(self, params): + return np.ones_like(params) + + def update(self, step, grads, params, avg_sq_grad, opt_params): + del step + learning_rate = opt_params["learning_rate"] + gamma = opt_params["gamma"] + eps = opt_params["eps"] + avg_sq_grad = avg_sq_grad * gamma + grads**2 * (1. - gamma) + params = params - (learning_rate * grads / + (np.sqrt(avg_sq_grad) + eps)).astype(params.dtype) + return params, avg_sq_grad + + +class Adam(Optimizer): + """Adam optimizer.""" + + def __init__(self, learning_rate, weight_decay_rate=1e-5, # pylint: disable=useless-super-delegation + b1=0.9, b2=0.999, eps=1e-5): + """Create the Adam optimizer. + + Args: + learning_rate: a postitive scalar value for the initial learning rate. + weight_decay_rate: rate at which to decay weights. + b1: optional, a positive scalar value for beta_1, the exponential decay + rate for the first moment estimates (default 0.9). + b2: optional, a positive scalar value for beta_2, the exponential decay + rate for the second moment estimates (default 0.999). + eps: optional, a positive scalar value for epsilon, a small constant for + numerical stability (default 1e-8). + """ + super(Adam, self).__init__( + learning_rate=learning_rate, + weight_decay_rate=weight_decay_rate, + b1=b1, + b2=b2, + eps=eps, + ) + + def init(self, params): + m = np.zeros_like(params) + v = np.zeros_like(params) + return m, v + + def update(self, step, grads, params, slots, opt_params): + m, v = slots + learning_rate = opt_params["learning_rate"] + weight_decay_rate = opt_params["weight_decay_rate"] + b1 = opt_params["b1"] + b2 = opt_params["b2"] + eps = opt_params["eps"] + m = (1 - b1) * grads + b1 * m # First moment estimate. + v = (1 - b2) * (grads ** 2) + b2 * v # Second moment estimate. + mhat = m / (1 - b1 ** (step + 1)) # Bias correction. + vhat = v / (1 - b2 ** (step + 1)) + params = (1 - weight_decay_rate) * params - ( + learning_rate * mhat / (np.sqrt(vhat) + eps)).astype(params.dtype) + return params, (m, v) + + +class Adafactor(Optimizer): + """Adafactor optimizer.""" + + def __init__(self, + learning_rate, + factored=True, + multiply_by_parameter_scale=True, + do_clipping=True, + do_momentum=False, + beta1=0.0, + decay_rate=0.8, + clipping_threshold=1.0, + weight_decay_rate=1e-5, + epsilon1=1e-30, + epsilon2=1e-3): + """Create the Adafactor optimizer. + + Adafactor is described in https://arxiv.org/abs/1804.04235. + + Args: + learning_rate: float: trax-provided learning rate. + factored: boolean: whether to use factored second-moment estimator for 2d + variables. + multiply_by_parameter_scale: boolean: if True, then scale provided + learning_rate by parameter norm. if False, provided learning_rate is + absolute step size. + do_clipping: whether to clip gradients; if True, set clipping_theshold. + do_momentum: whether to use momentum; if True, set beta1. + beta1: a float value between 0 and 1, enables momentum and uses extra + memory if nonzero! Off by default. + decay_rate: float: controls second-moment exponential decay schedule. + clipping_threshold: an optional float >= 1, if None no update clipping. + weight_decay_rate: rate at which to decay weights. + epsilon1: Regularization constant for squared gradient. + epsilon2: Regularization constant for parameter scale. + """ + # These 4 parameters are not configurable once the class is created. + self._factored = factored + self._multiply_by_parameter_scale = multiply_by_parameter_scale + self._do_clipping = do_clipping + self._do_momentum = do_momentum + # Dynamically configurable parameters will be passed to the update function. + super(Adafactor, self).__init__( + learning_rate=learning_rate, + beta1=beta1, + decay_rate=decay_rate, + clipping_threshold=clipping_threshold, + weight_decay_rate=weight_decay_rate, + epsilon1=epsilon1, + epsilon2=epsilon2, + ) + + @staticmethod + def _decay_rate_pow(i, exponent=0.8): + """Default Adafactor second-moment decay schedule.""" + t = np.array(i, np.float32) + 1.0 + return 1.0 - t**(-exponent) + + def init(self, params): + shape = params.shape + slots = [] + if self._factored and len(shape) >= 2: + v_row = np.zeros(shape[:-1], dtype=np.float32) + v_col = np.zeros(shape[:-2] + shape[-1:], dtype=np.float32) + slots.extend([v_row, v_col]) + else: + v = np.zeros_like(params) + slots.append(v) + if self._do_momentum: + m = np.zeros_like(params) + slots.append(m) + return slots + + def update(self, step, grads, params, slots, opt_params): + updates = [] + learning_rate = opt_params["learning_rate"] + beta1 = opt_params["beta1"] + decay_rate = opt_params["decay_rate"] + clipping_threshold = opt_params["clipping_threshold"] + weight_decay_rate = opt_params["weight_decay_rate"] + epsilon1 = opt_params["epsilon1"] + epsilon2 = opt_params["epsilon2"] + decay_rate = self._decay_rate_pow(step, exponent=decay_rate) + update_scale = learning_rate + if self._multiply_by_parameter_scale: + update_scale *= np.maximum( + np.sqrt(np.mean(params * params)), epsilon2) + mixing_rate = 1.0 - decay_rate + + grads_sqr = grads * grads + epsilon1 + if self._factored and len(params.shape) >= 2: + v_row = slots.pop(0) + v_col = slots.pop(0) + new_v_row = decay_rate * v_row + mixing_rate * np.mean(grads_sqr, axis=-1) + new_v_col = decay_rate * v_col + mixing_rate * np.mean(grads_sqr, axis=-2) + updates.extend([new_v_row, new_v_col]) + row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True) + row_factor = (new_v_row / row_col_mean)**-0.5 + col_factor = (new_v_col)**-0.5 + y = ( + grads * np.expand_dims(row_factor, axis=-1) * + np.expand_dims(col_factor, axis=-2)) + else: + v = slots.pop(0) + new_v = decay_rate * v + mixing_rate * grads_sqr + updates.append(new_v) + y = grads * (new_v)**-0.5 + + if self._do_clipping: + clipping_denom = ( + np.maximum(1.0, np.sqrt(np.mean(y * y)) / clipping_threshold)) + y /= clipping_denom + + subtrahend = update_scale * y + if self._do_momentum: + m = slots.pop(0) + new_m = beta1 * m + (1.0 - beta1) * subtrahend + subtrahend = new_m + updates.append(new_m) + + new_params = (1 - weight_decay_rate) * params - subtrahend + # TODO(lukaszkaiser): why is the astype needed here? Check and correct. + return new_params.astype(params.dtype), updates + + +class SM3(Optimizer): + """SM3 optimizer.""" + + def __init__(self, learning_rate, momentum=0.9): # pylint: disable=useless-super-delegation + """Create the SM3 optimizer. + + Memory-Efficient Adaptive Optimization for Large-Scale Learning. + https://arxiv.org/abs/1901.11150 + + Args: + learning_rate: a postitive scalar value for the initial learning rate. + momentum: optional, a positive scalar value for momentum + """ + super(SM3, self).__init__( + learning_rate=learning_rate, + momentum=momentum, + ) + + def init(self, params): + vs = [np.zeros(sz, dtype=params.dtype) for sz in params.shape] + return (np.zeros_like(params), vs) + + def _update_diagonal(self, grads, params, m, v, opt_params): + learning_rate = opt_params["learning_rate"] + momentum = opt_params["momentum"] + v[0] += grads * grads + preconditioner = np.where(v[0] > 0, 1.0 / np.sqrt(v[0]), + np.zeros_like(v[0])) + preconditioned_grads = preconditioner * grads + m = (1 - momentum) * preconditioned_grads + momentum * m + params = params - (learning_rate * m).astype(params.dtype) + return params, (m, v) + + def _expanded_shape(self, shape, axis): + # Replaces a `shape` of [M, N, K] with 1 in all dimensions except for i. + # For eg: i = 1 returns [1, N, 1]. + rank = len(shape) + return [1] * axis + [shape[axis]] + [1] * (rank - axis - 1) + + def _minimum(self, tensor_list): + minimum = tensor_list[0] + for i in range(1, len(tensor_list)): + minimum = np.minimum(minimum, tensor_list[i]) + return minimum + + def _update_sketched(self, grads, params, m, v, opt_params): + """Update for higher-rank parameters.""" + learning_rate = opt_params["learning_rate"] + momentum = opt_params["momentum"] + shape = params.shape + rank = len(shape) + reshaped_accumulators = [np.reshape(v[i], self._expanded_shape(shape, i)) + for i in range(rank)] + current_accumulator = self._minimum(reshaped_accumulators) + current_accumulator += grads * grads + accumulator_inv_sqrt = np.where(current_accumulator > 0.0, + 1.0 / np.sqrt(current_accumulator), + np.zeros_like(current_accumulator)) + preconditioned_gradient = grads * accumulator_inv_sqrt + m = (1.0 - momentum) * preconditioned_gradient + momentum * m + params = params - (learning_rate * m).astype(params.dtype) + for i in range(len(v)): + axes = list(range(int(i))) + list(range(int(i) + 1, rank)) + dim_accumulator = np.amax(current_accumulator, axis=axes) + v[i] = dim_accumulator + return params, (m, v) + + def update(self, step, grads, params, slots, opt_params): + del step + m, v = slots + shape = params.shape + rank = len(shape) + if rank > 1: + return self._update_sketched(grads, params, m, v, opt_params) + else: + return self._update_diagonal(grads, params, m, v, opt_params) diff --git a/trax/rl/__init__.py b/trax/rl/__init__.py new file mode 100644 index 000000000..3192156d7 --- /dev/null +++ b/trax/rl/__init__.py @@ -0,0 +1,52 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Trax RL library.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gin + +from trax.rl import simulated_env_problem + + +def configure_rl(*args, **kwargs): + kwargs["module"] = "trax.rl" + return gin.external_configurable(*args, **kwargs) + + +def configure_simulated_env_problem(*args, **kwargs): + kwargs["blacklist"] = [ + "batch_size", "observation_space", "action_space", "reward_range", + "discrete_rewards", "history_stream", "output_dir"] + return configure_rl(*args, **kwargs) + + +# pylint: disable=invalid-name +RawSimulatedEnvProblem = configure_simulated_env_problem( + simulated_env_problem.RawSimulatedEnvProblem) +SerializedSequenceSimulatedEnvProblem = configure_simulated_env_problem( + simulated_env_problem.SerializedSequenceSimulatedEnvProblem) + + +# pylint: disable=invalid-name +cartpole_done_fn = configure_rl(simulated_env_problem.cartpole_done_fn) +cartpole_reward_fn = configure_rl(simulated_env_problem.cartpole_reward_fn) +acrobot_done_fn = configure_rl(simulated_env_problem.acrobot_done_fn) +acrobot_reward_fn = configure_rl(simulated_env_problem.acrobot_reward_fn) +onlinetune_done_fn = configure_rl(simulated_env_problem.onlinetune_done_fn) +onlinetune_reward_fn = configure_rl(simulated_env_problem.onlinetune_reward_fn) diff --git a/trax/rl/base_trainer.py b/trax/rl/base_trainer.py new file mode 100644 index 000000000..d1c8360ce --- /dev/null +++ b/trax/rl/base_trainer.py @@ -0,0 +1,141 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base class for RL trainers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from absl import logging +from tensorflow.io import gfile +from trax import utils + + +class BaseTrainer(object): + """Base class for RL trainers.""" + + def __init__( + self, train_env, eval_env, output_dir, + trajectory_dump_dir=None, trajectory_dump_min_count_per_shard=16, + async_mode=False, + ): + """Base class constructor. + + Args: + train_env: EnvProblem to use for training. Settable. + eval_env: EnvProblem to use for evaluation. Settable. + output_dir: Directory to save checkpoints and metrics to. + trajectory_dump_dir: Directory to dump trajectories to. Trajectories + are saved in shards of name .pkl under this directory. Settable. + trajectory_dump_min_count_per_shard: Minimum number of trajectories to + collect before dumping in a new shard. Sharding is for efficient + shuffling for model training in SimPLe. + async_mode: (bool) If True, this means we are in async mode and we read + trajectories from a location rather than interact with the environment. + """ + self.train_env = train_env + self.eval_env = eval_env + self._output_dir = output_dir + gfile.makedirs(self._output_dir) + self.trajectory_dump_dir = trajectory_dump_dir + self._trajectory_dump_min_count_per_shard = ( + trajectory_dump_min_count_per_shard) + self._trajectory_buffer = [] + self._async_mode = async_mode + + @property + def async_mode(self): + return self._async_mode + + @async_mode.setter + def async_mode(self, async_mode): + logging.vlog(1, "Changing async mode from %s to: %s", + self._async_mode, async_mode) + self._async_mode = async_mode + + @property + def epoch(self): + raise NotImplementedError + + def train_epoch(self, evaluate=True): + raise NotImplementedError + + def evaluate(self): + raise NotImplementedError + + def save(self): + raise NotImplementedError + + def flush_summaries(self): + raise NotImplementedError + + def dump_trajectories(self, force=False): + """Dumps trajectories in a new shard. + + Should be called at most once per epoch. + + Args: + force: (bool) Whether to complete unfinished trajectories and create + a new shard even if we have not reached the minimum size. + """ + pkl_module = utils.get_pickle_module() + if self.trajectory_dump_dir is None: + return + gfile.makedirs(self.trajectory_dump_dir) + + trajectories = self.train_env.trajectories + if force: + trajectories.complete_all_trajectories() + + # complete_all_trajectories() also adds trajectories that were just reset. + # We don't want them since they have just the initial observation and no + # actions, so we filter them out. + def has_any_action(trajectory): + return ( + trajectory.time_steps and trajectory.time_steps[0].action is not None) + self._trajectory_buffer.extend( + filter(has_any_action, trajectories.completed_trajectories)) + + trajectories.clear_completed_trajectories() + ready = ( + len(self._trajectory_buffer) >= + self._trajectory_dump_min_count_per_shard + ) + if ready or force: + shard_path = os.path.join( + self.trajectory_dump_dir, "{}.pkl".format(self.epoch)) + if gfile.exists(shard_path): + # Since we do an extra dump at the end of the training loop, we + # sometimes dump 2 times in the same epoch. When this happens, merge the + # two sets of trajectories. + with gfile.GFile(shard_path, "rb") as f: + self._trajectory_buffer = pkl_module.load(f) + self._trajectory_buffer + with gfile.GFile(shard_path, "wb") as f: + pkl_module.dump(self._trajectory_buffer, f) + self._trajectory_buffer = [] + + def training_loop(self, n_epochs, evaluate=True): + logging.info("Starting the RL training loop.") + for _ in range(self.epoch, n_epochs): + self.train_epoch(evaluate=evaluate) + self.dump_trajectories() + self.save() + self.dump_trajectories(force=True) + if evaluate: + self.evaluate() + self.flush_summaries() diff --git a/trax/rl/base_trainer_test.py b/trax/rl/base_trainer_test.py new file mode 100644 index 000000000..a04f26729 --- /dev/null +++ b/trax/rl/base_trainer_test.py @@ -0,0 +1,144 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.rl.base_trainer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import cloudpickle as pickle +import numpy as np + +from tensor2tensor.envs import gym_env_problem +from tensorflow import test +from trax.rl import base_trainer + + +class FakeTrainer(base_trainer.BaseTrainer): + """Fake Trainer. + + Adds one complete and one incomplete trajectory every epoch. + """ + + def __init__(self, *args, **kwargs): + super(FakeTrainer, self).__init__(*args, **kwargs) + self._epoch = 0 + self._should_reset = True + + @property + def epoch(self): + return self._epoch + + def train_epoch(self): + trajectories = self.train_env.trajectories + if self._should_reset: + trajectories.reset(indices=np.arange(2), observations=np.zeros(2)) + self._should_reset = False + trajectories.step( + observations=np.zeros(2), + raw_rewards=np.zeros(2), + processed_rewards=np.zeros(2), + dones=np.array([False, True]), + actions=np.zeros(2), + ) + # Reset the trajectories that are done, as + # env_problem_utils.play_env_problem_with_policy does. + trajectories.reset(indices=np.array([1]), observations=np.zeros(1)) + self._epoch += 1 + + def evaluate(self): + pass + + def save(self): + pass + + def flush_summaries(self): + pass + + +class BaseTrainerTest(test.TestCase): + + def _make_trainer(self, min_count_per_shard): + train_env = gym_env_problem.GymEnvProblem( + base_env_name="Acrobot-v1", batch_size=2) + eval_env = gym_env_problem.GymEnvProblem( + base_env_name="Acrobot-v1", batch_size=1) + temp_dir = self.get_temp_dir() + return FakeTrainer( + train_env, eval_env, + output_dir=temp_dir, + trajectory_dump_dir=temp_dir, + trajectory_dump_min_count_per_shard=min_count_per_shard, + ) + + def _assert_no_shard_exists(self, trajectory_dir): + self.assertFalse(os.listdir(trajectory_dir)) + + def _assert_single_shard_exists_and_has_trajectories( + self, trajectory_dir, expected_trajectory_lengths): + shard_filenames = os.listdir(trajectory_dir) + self.assertEqual(len(shard_filenames), 1) + shard_path = os.path.join(trajectory_dir, shard_filenames[0]) + with open(shard_path, "rb") as f: + trajectories = pickle.load(f) + actual_trajectory_lengths = [ + len(trajectory.time_steps) for trajectory in trajectories] + self.assertEqual( + list(sorted(actual_trajectory_lengths)), + list(sorted(expected_trajectory_lengths)), + ) + + def test_dumps_full_shard(self): + trainer = self._make_trainer(min_count_per_shard=2) + trajectory_dir = self.get_temp_dir() + + # Add one complete trajectory to the buffer. Should not dump yet. + trainer.train_epoch() + trainer.dump_trajectories() + self._assert_no_shard_exists(trajectory_dir) + + # Add the second complete trajectory. Now we should dump. + trainer.train_epoch() + trainer.dump_trajectories() + self._assert_single_shard_exists_and_has_trajectories( + trajectory_dir, [2, 2]) + + def test_dumps_incomplete_trajectories_when_force_is_true(self): + trainer = self._make_trainer(min_count_per_shard=2) + trajectory_dir = self.get_temp_dir() + + # Add one complete and one incomplete trajectory to the buffer. Should dump. + trainer.train_epoch() + trainer.dump_trajectories(force=True) + self._assert_single_shard_exists_and_has_trajectories( + trajectory_dir, [2, 2]) + + def test_dumps_incomplete_shard_when_force_is_true(self): + trainer = self._make_trainer(min_count_per_shard=4) + trajectory_dir = self.get_temp_dir() + + # Add one complete and one incomplete trajectory to the buffer. Should dump, + # even though we don't have a full shard yet. + trainer.train_epoch() + trainer.dump_trajectories(force=True) + self._assert_single_shard_exists_and_has_trajectories( + trajectory_dir, [2, 2]) + + +if __name__ == "__main__": + test.main() diff --git a/trax/rl/configs/acrobot.gin b/trax/rl/configs/acrobot.gin new file mode 100644 index 000000000..8cbffceb3 --- /dev/null +++ b/trax/rl/configs/acrobot.gin @@ -0,0 +1,44 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.models +import trax.rl.trainers + +# Parameters for FrameStackMLP: +# ============================================================================== +FrameStackMLP.n_frames = 1 + +# Parameters for PPO: +# ============================================================================== +PPO.n_optimizer_steps = 32 +PPO.target_kl = 1000 # Virtually infinite. +PPO.boundary = 512 +PPO.max_timestep = 512 +PPO.max_timestep_eval = 20000 +PPO.random_seed = None +PPO.gamma = 0.99 +PPO.lambda_ = 0.95 +PPO.c1 = 1.0 +PPO.c2 = 0.0 +PPO.eval_every_n = 500 +PPO.done_frac_for_policy_save = 0.9 +PPO.n_evals = 16 +PPO.len_history_for_policy = 1 +PPO.eval_temperatures = (1.0, 0.5) +PPO.policy_and_value_model = @trax.models.FrameStackMLP + +# Parameters for train_rl: +# ============================================================================== +train_rl.env_name = "Acrobot-v1" +train_rl.n_epochs = 40000 diff --git a/trax/rl/configs/acrobot.ginE b/trax/rl/configs/acrobot.ginE new file mode 100644 index 000000000..e122270ae --- /dev/null +++ b/trax/rl/configs/acrobot.ginE @@ -0,0 +1,30 @@ +import tensor2tensor.trax.models +import tensor2tensor.trax.rl.trainers + +# Parameters for FrameStackMLP: +# ============================================================================== +FrameStackMLP.n_frames = 1 + +# Parameters for PPO: +# ============================================================================== +PPO.n_optimizer_steps = 32 +PPO.target_kl = 1000 # Virtually infinite. +PPO.boundary = 512 +PPO.max_timestep = 512 +PPO.max_timestep_eval = 20000 +PPO.random_seed = None +PPO.gamma = 0.99 +PPO.lambda_ = 0.95 +PPO.c1 = 1.0 +PPO.c2 = 0.0 +PPO.eval_every_n = 500 +PPO.done_frac_for_policy_save = 0.9 +PPO.n_evals = 16 +PPO.len_history_for_policy = 1 +PPO.eval_temperatures = (1.0, 0.5) +PPO.policy_and_value_model = @trax.models.FrameStackMLP + +# Parameters for train_rl: +# ============================================================================== +train_rl.env_name = "Acrobot-v1" +train_rl.n_epochs = 40000 diff --git a/trax/rl/configs/acrobot_transformer.gin b/trax/rl/configs/acrobot_transformer.gin new file mode 100644 index 000000000..e28a302c9 --- /dev/null +++ b/trax/rl/configs/acrobot_transformer.gin @@ -0,0 +1,48 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.models +import trax.rl.trainers + +# Parameters for TransformerDecoder: +# ============================================================================== +TransformerDecoder.d_model = 64 +TransformerDecoder.d_ff = 128 +TransformerDecoder.dropout = 0.1 +TransformerDecoder.n_heads = 2 +TransformerDecoder.n_layers = 1 + +# Parameters for PPO: +# ============================================================================== +PPO.n_optimizer_steps = 32 +PPO.target_kl = 1000 # Virtually infinite. +PPO.boundary = 512 +PPO.max_timestep = 512 +PPO.max_timestep_eval = 20000 +PPO.random_seed = None +PPO.gamma = 0.99 +PPO.lambda_ = 0.95 +PPO.c1 = 1.0 +PPO.c2 = 0.0 +PPO.eval_every_n = 500 +PPO.done_frac_for_policy_save = 0.9 +PPO.n_evals = 16 +PPO.len_history_for_policy = None +PPO.eval_temperatures = (1.0, 0.5) +PPO.policy_and_value_model = @trax.models.TransformerDecoder + +# Parameters for train_rl: +# ============================================================================== +train_rl.env_name = "Acrobot-v1" +train_rl.n_epochs = 40000 diff --git a/trax/rl/configs/acrobot_transformer.ginE b/trax/rl/configs/acrobot_transformer.ginE new file mode 100644 index 000000000..16343023d --- /dev/null +++ b/trax/rl/configs/acrobot_transformer.ginE @@ -0,0 +1,34 @@ +import tensor2tensor.trax.models +import tensor2tensor.trax.rl.trainers + +# Parameters for TransformerDecoder: +# ============================================================================== +TransformerDecoder.d_model = 64 +TransformerDecoder.d_ff = 128 +TransformerDecoder.dropout = 0.1 +TransformerDecoder.n_heads = 2 +TransformerDecoder.n_layers = 1 + +# Parameters for PPO: +# ============================================================================== +PPO.n_optimizer_steps = 32 +PPO.target_kl = 1000 # Virtually infinite. +PPO.boundary = 512 +PPO.max_timestep = 512 +PPO.max_timestep_eval = 20000 +PPO.random_seed = None +PPO.gamma = 0.99 +PPO.lambda_ = 0.95 +PPO.c1 = 1.0 +PPO.c2 = 0.0 +PPO.eval_every_n = 500 +PPO.done_frac_for_policy_save = 0.9 +PPO.n_evals = 16 +PPO.len_history_for_policy = None +PPO.eval_temperatures = (1.0, 0.5) +PPO.policy_and_value_model = @trax.models.TransformerDecoder + +# Parameters for train_rl: +# ============================================================================== +train_rl.env_name = "Acrobot-v1" +train_rl.n_epochs = 40000 diff --git a/trax/rl/configs/atari.gin b/trax/rl/configs/atari.gin new file mode 100644 index 000000000..c7a454ea8 --- /dev/null +++ b/trax/rl/configs/atari.gin @@ -0,0 +1,44 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.models +import trax.rl.trainers + +# Parameters for PPO: +# ============================================================================== +PPO.n_optimizer_steps = 4 +PPO.target_kl = 0.01 +PPO.boundary = 20 +PPO.max_timestep = 128 +PPO.max_timestep_eval = 20000 +PPO.random_seed = None +PPO.gamma = 0.99 +PPO.lambda_ = 0.95 +PPO.c1 = 1.0 +PPO.c2 = 0.01 +PPO.eval_every_n = 500 +PPO.done_frac_for_policy_save = 0.9 +PPO.n_evals = 16 +PPO.len_history_for_policy = 4 +PPO.eval_temperatures = (1.0, 0.5) +PPO.policy_and_value_model = @trax.models.AtariCnn + +# Parameters for train_rl: +# ============================================================================== +train_rl.env_name = "PongNoFrameskip-v4" +train_rl.n_epochs = 40000 +train_rl.clip_rewards = True +train_rl.max_timestep = 10000 +train_rl.rendered_env = True +train_rl.resize_dims = (105, 80) diff --git a/trax/rl/configs/atari.ginE b/trax/rl/configs/atari.ginE new file mode 100644 index 000000000..e4c0ab2c1 --- /dev/null +++ b/trax/rl/configs/atari.ginE @@ -0,0 +1,30 @@ +import tensor2tensor.trax.models +import tensor2tensor.trax.rl.trainers + +# Parameters for PPO: +# ============================================================================== +PPO.n_optimizer_steps = 4 +PPO.target_kl = 0.01 +PPO.boundary = 20 +PPO.max_timestep = 128 +PPO.max_timestep_eval = 20000 +PPO.random_seed = None +PPO.gamma = 0.99 +PPO.lambda_ = 0.95 +PPO.c1 = 1.0 +PPO.c2 = 0.01 +PPO.eval_every_n = 500 +PPO.done_frac_for_policy_save = 0.9 +PPO.n_evals = 16 +PPO.len_history_for_policy = 4 +PPO.eval_temperatures = (1.0, 0.5) +PPO.policy_and_value_model = @trax.models.AtariCnn + +# Parameters for train_rl: +# ============================================================================== +train_rl.env_name = "PongNoFrameskip-v4" +train_rl.n_epochs = 40000 +train_rl.clip_rewards = True +train_rl.max_timestep = 10000 +train_rl.rendered_env = True +train_rl.resize_dims = (105, 80) diff --git a/trax/rl/configs/atari_regression_test.gin b/trax/rl/configs/atari_regression_test.gin new file mode 100644 index 000000000..6c84b01e8 --- /dev/null +++ b/trax/rl/configs/atari_regression_test.gin @@ -0,0 +1,44 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.models +import trax.rl.trainers + +# Parameters for PPO: +# ============================================================================== +PPO.n_optimizer_steps = 30 +PPO.target_kl = 0.01 +PPO.boundary = 20 +PPO.max_timestep = 128 +PPO.max_timestep_eval = 20000 +PPO.random_seed = None +PPO.gamma = 0.99 +PPO.lambda_ = 0.95 +PPO.c1 = 1.0 +PPO.c2 = 0.01 +PPO.eval_every_n = 500 +PPO.done_frac_for_policy_save = 0.9 +PPO.n_evals = 16 +PPO.len_history_for_policy = 4 +PPO.eval_temperatures = (1.0, 0.5) +PPO.policy_and_value_model = @trax.models.AtariCnn + +# Parameters for train_rl: +# ============================================================================== +train_rl.env_name = "PongNoFrameskip-v4" +train_rl.n_epochs = 4000 +train_rl.clip_rewards = True +train_rl.max_timestep = 10000 +train_rl.rendered_env = True +train_rl.resize_dims = (105, 80) diff --git a/trax/rl/configs/atari_regression_test.ginE b/trax/rl/configs/atari_regression_test.ginE new file mode 100644 index 000000000..54a96d953 --- /dev/null +++ b/trax/rl/configs/atari_regression_test.ginE @@ -0,0 +1,30 @@ +import tensor2tensor.trax.models +import tensor2tensor.trax.rl.trainers + +# Parameters for PPO: +# ============================================================================== +PPO.n_optimizer_steps = 30 +PPO.target_kl = 0.01 +PPO.boundary = 20 +PPO.max_timestep = 128 +PPO.max_timestep_eval = 20000 +PPO.random_seed = None +PPO.gamma = 0.99 +PPO.lambda_ = 0.95 +PPO.c1 = 1.0 +PPO.c2 = 0.01 +PPO.eval_every_n = 500 +PPO.done_frac_for_policy_save = 0.9 +PPO.n_evals = 16 +PPO.len_history_for_policy = 4 +PPO.eval_temperatures = (1.0, 0.5) +PPO.policy_and_value_model = @trax.models.AtariCnn + +# Parameters for train_rl: +# ============================================================================== +train_rl.env_name = "PongNoFrameskip-v4" +train_rl.n_epochs = 4000 +train_rl.clip_rewards = True +train_rl.max_timestep = 10000 +train_rl.rendered_env = True +train_rl.resize_dims = (105, 80) diff --git a/trax/rl/configs/env_online_tune_transformer_imagenet64_16gb.gin b/trax/rl/configs/env_online_tune_transformer_imagenet64_16gb.gin new file mode 100644 index 000000000..925d1536b --- /dev/null +++ b/trax/rl/configs/env_online_tune_transformer_imagenet64_16gb.gin @@ -0,0 +1,119 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.models +import trax.optimizers +import trax.rl +import trax.rl.envs + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 1 +batch_fn.eval_batch_size = 16 +batch_fn.max_eval_length = 12288 # 64 * 64 * 3 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_image_imagenet64_gen_flat_rev' +inputs.input_name = 'targets' + +# Parameters for train_and_eval_dataset: +# ============================================================================== +train_and_eval_dataset.eval_holdout_size = 0.05 +train_and_eval_dataset.eval_shuffle_files = True + +# Parameters for MemoryEfficientCausalAttention: +# ============================================================================== +MemoryEfficientCausalAttention.dropout = 0.0 +MemoryEfficientCausalAttention.loop_stride = 512 + +# Parameters for MergedHashedCausalAttention: +# ============================================================================== +MergedHashedCausalAttention.dropout = 0.0 +MergedHashedCausalAttention.n_bins = 16 +MergedHashedCausalAttention.bin_by_time = True +MergedMultiHashedCausalAttention.one_rng = False + +# Parameters for MergedMultiHashedCausalAttention: +# ============================================================================== +MergedMultiHashedCausalAttention.dropout = 0.0 +MergedMultiHashedCausalAttention.n_bins = 64 +MergedMultiHashedCausalAttention.n_hashes = 2 +MergedMultiHashedCausalAttention.n_buckets_per_bin = 2 +MergedMultiHashedCausalAttention.bin_by_time = False +MergedMultiHashedCausalAttention.one_rng = False +MergedMultiHashedCausalAttention.drop_for_hash_rate = 0.1 +MergedMultiHashedCausalAttention.hard_k = 32 + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = @trax.layers.MergedMultiHashedCausalAttention +TransformerLM.d_attention_key = 64 +TransformerLM.d_attention_value = 64 +TransformerLM.d_model = 1024 +TransformerLM.d_ff = 2048 +TransformerLM.dropout = 0.0 +TransformerLM.max_len = 12288 # 64 * 64 * 3 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 4 +TransformerLM.n_layers = 3 +TransformerLM.share_kv = True +TransformerLM.vocab_size = 256 + +# Parameters for OnlineTuneEnv: +# ============================================================================== +OnlineTuneEnv.inputs = @trax.inputs.inputs +OnlineTuneEnv.model = @trax.models.TransformerLM +OnlineTuneEnv.optimizer = @trax.optimizers.Adafactor +OnlineTuneEnv.control_configs = ( + ("learning_rate", 1e-3, (1e-9, 1e-2), False), + ("weight_decay_rate", 1e-5, (1e-9, 1e-3), False), + + ("dropout_embedding", 0.1, (0.0, 0.9), True), + + ("dropout_attention_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_initial", 0.1, (0.0, 0.9), True), + + ("dropout_attention_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_middle", 0.1, (0.0, 0.9), True), + + ("dropout_attention_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_final", 0.1, (0.0, 0.9), True), +) +OnlineTuneEnv.nontrainable_param_map = { + # "dropout_{layer_type}_{block_index}": "dropout_{layer_type}_{block_group}" + + "dropout_attention_0": "dropout_attention_initial", + "dropout_ff_middle_0": "dropout_ff_middle_initial", + "dropout_ff_final_0": "dropout_ff_final_initial", + + "dropout_attention_1": "dropout_attention_middle", + "dropout_ff_middle_1": "dropout_ff_middle_middle", + "dropout_ff_final_1": "dropout_ff_final_middle", + + "dropout_attention_2": "dropout_attention_final", + "dropout_ff_middle_2": "dropout_ff_middle_final", + "dropout_ff_final_2": "dropout_ff_final_final", +} +OnlineTuneEnv.include_controls_in_observation = False +OnlineTuneEnv.train_steps = 150 +OnlineTuneEnv.eval_steps = 2 +OnlineTuneEnv.env_steps = 100 +OnlineTuneEnv.observation_range = (0.0, 10.0) +OnlineTuneEnv.action_multipliers = (0.5, 0.8, 0.95, 1.0, 1.05, 1.25, 2.0) diff --git a/trax/rl/configs/env_online_tune_transformer_imagenet64_16gb.ginE b/trax/rl/configs/env_online_tune_transformer_imagenet64_16gb.ginE new file mode 100644 index 000000000..3aeb46b66 --- /dev/null +++ b/trax/rl/configs/env_online_tune_transformer_imagenet64_16gb.ginE @@ -0,0 +1,105 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.rl +import tensor2tensor.trax.rl.envs + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 1 +batch_fn.eval_batch_size = 16 +batch_fn.max_eval_length = 12288 # 64 * 64 * 3 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_image_imagenet64_gen_flat_rev' +inputs.input_name = 'targets' + +# Parameters for train_and_eval_dataset: +# ============================================================================== +train_and_eval_dataset.eval_holdout_size = 0.05 +train_and_eval_dataset.eval_shuffle_files = True + +# Parameters for MemoryEfficientCausalAttention: +# ============================================================================== +MemoryEfficientCausalAttention.dropout = 0.0 +MemoryEfficientCausalAttention.loop_stride = 512 + +# Parameters for MergedHashedCausalAttention: +# ============================================================================== +MergedHashedCausalAttention.dropout = 0.0 +MergedHashedCausalAttention.n_bins = 16 +MergedHashedCausalAttention.bin_by_time = True +MergedMultiHashedCausalAttention.one_rng = False + +# Parameters for MergedMultiHashedCausalAttention: +# ============================================================================== +MergedMultiHashedCausalAttention.dropout = 0.0 +MergedMultiHashedCausalAttention.n_bins = 64 +MergedMultiHashedCausalAttention.n_hashes = 2 +MergedMultiHashedCausalAttention.n_buckets_per_bin = 2 +MergedMultiHashedCausalAttention.bin_by_time = False +MergedMultiHashedCausalAttention.one_rng = False +MergedMultiHashedCausalAttention.drop_for_hash_rate = 0.1 +MergedMultiHashedCausalAttention.hard_k = 32 + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = @trax.layers.MergedMultiHashedCausalAttention +TransformerLM.d_attention_key = 64 +TransformerLM.d_attention_value = 64 +TransformerLM.d_model = 1024 +TransformerLM.d_ff = 2048 +TransformerLM.dropout = 0.0 +TransformerLM.max_len = 12288 # 64 * 64 * 3 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 4 +TransformerLM.n_layers = 3 +TransformerLM.share_kv = True +TransformerLM.vocab_size = 256 + +# Parameters for OnlineTuneEnv: +# ============================================================================== +OnlineTuneEnv.inputs = @trax.inputs.inputs +OnlineTuneEnv.model = @trax.models.TransformerLM +OnlineTuneEnv.optimizer = @trax.optimizers.Adafactor +OnlineTuneEnv.control_configs = ( + ("learning_rate", 1e-3, (1e-9, 1e-2), False), + ("weight_decay_rate", 1e-5, (1e-9, 1e-3), False), + + ("dropout_embedding", 0.1, (0.0, 0.9), True), + + ("dropout_attention_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_initial", 0.1, (0.0, 0.9), True), + + ("dropout_attention_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_middle", 0.1, (0.0, 0.9), True), + + ("dropout_attention_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_final", 0.1, (0.0, 0.9), True), +) +OnlineTuneEnv.nontrainable_param_map = { + # "dropout_{layer_type}_{block_index}": "dropout_{layer_type}_{block_group}" + + "dropout_attention_0": "dropout_attention_initial", + "dropout_ff_middle_0": "dropout_ff_middle_initial", + "dropout_ff_final_0": "dropout_ff_final_initial", + + "dropout_attention_1": "dropout_attention_middle", + "dropout_ff_middle_1": "dropout_ff_middle_middle", + "dropout_ff_final_1": "dropout_ff_final_middle", + + "dropout_attention_2": "dropout_attention_final", + "dropout_ff_middle_2": "dropout_ff_middle_final", + "dropout_ff_final_2": "dropout_ff_final_final", +} +OnlineTuneEnv.include_controls_in_observation = False +OnlineTuneEnv.train_steps = 150 +OnlineTuneEnv.eval_steps = 2 +OnlineTuneEnv.env_steps = 100 +OnlineTuneEnv.observation_range = (0.0, 10.0) +OnlineTuneEnv.action_multipliers = (0.5, 0.8, 0.95, 1.0, 1.05, 1.25, 2.0) diff --git a/trax/rl/configs/env_online_tune_transformer_lm1b_16gb.gin b/trax/rl/configs/env_online_tune_transformer_lm1b_16gb.gin new file mode 100644 index 000000000..983b4bd51 --- /dev/null +++ b/trax/rl/configs/env_online_tune_transformer_lm1b_16gb.gin @@ -0,0 +1,113 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.models +import trax.optimizers +import trax.rl +import trax.rl.envs + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 256 +batch_fn.eval_batch_size = 256 +batch_fn.max_eval_length = 2048 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_languagemodel_lm1b32k' +inputs.input_name = 'targets' + +# Parameters for train_and_eval_dataset: +# ============================================================================== +train_and_eval_dataset.eval_holdout_size = 0.05 +train_and_eval_dataset.eval_shuffle_files = True + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.lm1b_preprocess +lm1b_preprocess.max_target_length = 512 +lm1b_preprocess.max_eval_target_length = 2048 + +# Parameters for DotProductCausalAttention: +# ============================================================================== +DotProductCausalAttention.dropout = 0.1 + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = @trax.layers.DotProductCausalAttention +TransformerLM.d_model = 512 +TransformerLM.d_ff = 2048 +TransformerLM.dropout = 0.1 +TransformerLM.max_len = 2048 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = 6 +TransformerLM.vocab_size = 32000 + +# Parameters for OnlineTuneEnv: +# ============================================================================== +OnlineTuneEnv.inputs = @trax.inputs.inputs +OnlineTuneEnv.model = @trax.models.TransformerLM +OnlineTuneEnv.optimizer = @trax.optimizers.Adafactor +OnlineTuneEnv.train_steps = 500 +OnlineTuneEnv.eval_steps = 1 +OnlineTuneEnv.env_steps = 100 +OnlineTuneEnv.control_configs = ( + ("learning_rate", 1e-3, (1e-9, 1e-2), False), + ("weight_decay_rate", 1e-5, (1e-9, 1e-3), False), + + ("dropout_embedding", 0.1, (0.0, 0.9), True), + + ("dropout_attention_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_initial", 0.1, (0.0, 0.9), True), + + ("dropout_attention_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_middle", 0.1, (0.0, 0.9), True), + + ("dropout_attention_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_final", 0.1, (0.0, 0.9), True), +) +OnlineTuneEnv.nontrainable_param_map = { + # "dropout_{layer_type}_{block_index}": "dropout_{layer_type}_{block_group}" + + "dropout_attention_0": "dropout_attention_initial", + "dropout_ff_middle_0": "dropout_ff_middle_initial", + "dropout_ff_final_0": "dropout_ff_final_initial", + + "dropout_attention_1": "dropout_attention_middle", + "dropout_ff_middle_1": "dropout_ff_middle_middle", + "dropout_ff_final_1": "dropout_ff_final_middle", + "dropout_attention_2": "dropout_attention_middle", + "dropout_ff_middle_2": "dropout_ff_middle_middle", + "dropout_ff_final_2": "dropout_ff_final_middle", + "dropout_attention_3": "dropout_attention_middle", + "dropout_ff_middle_3": "dropout_ff_middle_middle", + "dropout_ff_final_3": "dropout_ff_final_middle", + "dropout_attention_4": "dropout_attention_middle", + "dropout_ff_middle_4": "dropout_ff_middle_middle", + "dropout_ff_final_4": "dropout_ff_final_middle", + + "dropout_attention_5": "dropout_attention_final", + "dropout_ff_middle_5": "dropout_ff_middle_final", + "dropout_ff_final_5": "dropout_ff_final_final", +} +OnlineTuneEnv.include_controls_in_observation = False +OnlineTuneEnv.observation_range = (0.0, 10.0) +OnlineTuneEnv.action_multipliers = (0.5, 0.8, 0.95, 1.0, 1.05, 1.25, 2.0) +OnlineTuneEnv.mask_id = 0 diff --git a/trax/rl/configs/env_online_tune_transformer_lm1b_16gb.ginE b/trax/rl/configs/env_online_tune_transformer_lm1b_16gb.ginE new file mode 100644 index 000000000..c765da790 --- /dev/null +++ b/trax/rl/configs/env_online_tune_transformer_lm1b_16gb.ginE @@ -0,0 +1,99 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.rl +import tensor2tensor.trax.rl.envs + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 256 +batch_fn.eval_batch_size = 256 +batch_fn.max_eval_length = 2048 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_languagemodel_lm1b32k' +inputs.input_name = 'targets' + +# Parameters for train_and_eval_dataset: +# ============================================================================== +train_and_eval_dataset.eval_holdout_size = 0.05 +train_and_eval_dataset.eval_shuffle_files = True + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.lm1b_preprocess +lm1b_preprocess.max_target_length = 512 +lm1b_preprocess.max_eval_target_length = 2048 + +# Parameters for DotProductCausalAttention: +# ============================================================================== +DotProductCausalAttention.dropout = 0.1 + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = @trax.layers.DotProductCausalAttention +TransformerLM.d_model = 512 +TransformerLM.d_ff = 2048 +TransformerLM.dropout = 0.1 +TransformerLM.max_len = 2048 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = 6 +TransformerLM.vocab_size = 32000 + +# Parameters for OnlineTuneEnv: +# ============================================================================== +OnlineTuneEnv.inputs = @trax.inputs.inputs +OnlineTuneEnv.model = @trax.models.TransformerLM +OnlineTuneEnv.optimizer = @trax.optimizers.Adafactor +OnlineTuneEnv.train_steps = 500 +OnlineTuneEnv.eval_steps = 1 +OnlineTuneEnv.env_steps = 100 +OnlineTuneEnv.control_configs = ( + ("learning_rate", 1e-3, (1e-9, 1e-2), False), + ("weight_decay_rate", 1e-5, (1e-9, 1e-3), False), + + ("dropout_embedding", 0.1, (0.0, 0.9), True), + + ("dropout_attention_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_initial", 0.1, (0.0, 0.9), True), + + ("dropout_attention_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_middle", 0.1, (0.0, 0.9), True), + + ("dropout_attention_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_final", 0.1, (0.0, 0.9), True), +) +OnlineTuneEnv.nontrainable_param_map = { + # "dropout_{layer_type}_{block_index}": "dropout_{layer_type}_{block_group}" + + "dropout_attention_0": "dropout_attention_initial", + "dropout_ff_middle_0": "dropout_ff_middle_initial", + "dropout_ff_final_0": "dropout_ff_final_initial", + + "dropout_attention_1": "dropout_attention_middle", + "dropout_ff_middle_1": "dropout_ff_middle_middle", + "dropout_ff_final_1": "dropout_ff_final_middle", + "dropout_attention_2": "dropout_attention_middle", + "dropout_ff_middle_2": "dropout_ff_middle_middle", + "dropout_ff_final_2": "dropout_ff_final_middle", + "dropout_attention_3": "dropout_attention_middle", + "dropout_ff_middle_3": "dropout_ff_middle_middle", + "dropout_ff_final_3": "dropout_ff_final_middle", + "dropout_attention_4": "dropout_attention_middle", + "dropout_ff_middle_4": "dropout_ff_middle_middle", + "dropout_ff_final_4": "dropout_ff_final_middle", + + "dropout_attention_5": "dropout_attention_final", + "dropout_ff_middle_5": "dropout_ff_middle_final", + "dropout_ff_final_5": "dropout_ff_final_final", +} +OnlineTuneEnv.include_controls_in_observation = False +OnlineTuneEnv.observation_range = (0.0, 10.0) +OnlineTuneEnv.action_multipliers = (0.5, 0.8, 0.95, 1.0, 1.05, 1.25, 2.0) +OnlineTuneEnv.mask_id = 0 diff --git a/trax/rl/configs/env_online_tune_transformer_lm_wmt_ende_16gb.gin b/trax/rl/configs/env_online_tune_transformer_lm_wmt_ende_16gb.gin new file mode 100644 index 000000000..a93b1fb40 --- /dev/null +++ b/trax/rl/configs/env_online_tune_transformer_lm_wmt_ende_16gb.gin @@ -0,0 +1,110 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.models +import trax.optimizers +import trax.rl +import trax.rl.envs + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 128 +batch_fn.eval_batch_size = 128 +batch_fn.bucket_length = 64 +batch_fn.max_eval_length = 512 +batch_fn.buckets_include_inputs_in_length = True + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_translate_ende_wmt32k' + +# Parameters for train_and_eval_dataset: +# ============================================================================== +train_and_eval_dataset.eval_holdout_size = 0.05 +train_and_eval_dataset.eval_shuffle_files = True + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.wmt_concat_preprocess +wmt_concat_preprocess.max_length = 255 +wmt_concat_preprocess.max_eval_length = 255 + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.d_model = 512 +TransformerLM.d_ff = 2048 +TransformerLM.dropout = 0.1 +TransformerLM.max_len = 2048 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = 6 +TransformerLM.vocab_size = 33300 + +# Parameters for OnlineTuneEnv: +# ============================================================================== +OnlineTuneEnv.inputs = @trax.inputs.inputs +OnlineTuneEnv.model = @trax.models.TransformerLM +OnlineTuneEnv.optimizer = @trax.optimizers.Adafactor +OnlineTuneEnv.train_steps = 500 +OnlineTuneEnv.eval_steps = 1 +OnlineTuneEnv.env_steps = 100 +OnlineTuneEnv.has_weights = True +OnlineTuneEnv.control_configs = ( + ("learning_rate", 1e-3, (1e-9, 1e-2), False), + ("weight_decay_rate", 1e-5, (1e-9, 1e-3), False), + + ("dropout_embedding", 0.1, (0.0, 0.9), True), + + ("dropout_attention_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_initial", 0.1, (0.0, 0.9), True), + + ("dropout_attention_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_middle", 0.1, (0.0, 0.9), True), + + ("dropout_attention_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_final", 0.1, (0.0, 0.9), True), +) +OnlineTuneEnv.nontrainable_param_map = { + # "dropout_{layer_type}_{block_index}": "dropout_{layer_type}_{block_group}" + + "dropout_attention_0": "dropout_attention_initial", + "dropout_ff_middle_0": "dropout_ff_middle_initial", + "dropout_ff_final_0": "dropout_ff_final_initial", + + "dropout_attention_1": "dropout_attention_middle", + "dropout_ff_middle_1": "dropout_ff_middle_middle", + "dropout_ff_final_1": "dropout_ff_final_middle", + "dropout_attention_2": "dropout_attention_middle", + "dropout_ff_middle_2": "dropout_ff_middle_middle", + "dropout_ff_final_2": "dropout_ff_final_middle", + "dropout_attention_3": "dropout_attention_middle", + "dropout_ff_middle_3": "dropout_ff_middle_middle", + "dropout_ff_final_3": "dropout_ff_final_middle", + "dropout_attention_4": "dropout_attention_middle", + "dropout_ff_middle_4": "dropout_ff_middle_middle", + "dropout_ff_final_4": "dropout_ff_final_middle", + + "dropout_attention_5": "dropout_attention_final", + "dropout_ff_middle_5": "dropout_ff_middle_final", + "dropout_ff_final_5": "dropout_ff_final_final", +} +OnlineTuneEnv.include_controls_in_observation = False +OnlineTuneEnv.observation_range = (0.0, 10.0) +OnlineTuneEnv.action_multipliers = (0.5, 0.8, 0.95, 1.0, 1.05, 1.25, 2.0) +OnlineTuneEnv.mask_id = 0 diff --git a/trax/rl/configs/env_online_tune_transformer_lm_wmt_ende_16gb.ginE b/trax/rl/configs/env_online_tune_transformer_lm_wmt_ende_16gb.ginE new file mode 100644 index 000000000..f4cef2897 --- /dev/null +++ b/trax/rl/configs/env_online_tune_transformer_lm_wmt_ende_16gb.ginE @@ -0,0 +1,96 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.rl +import tensor2tensor.trax.rl.envs + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 128 +batch_fn.eval_batch_size = 128 +batch_fn.bucket_length = 64 +batch_fn.max_eval_length = 512 +batch_fn.buckets_include_inputs_in_length = True + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_translate_ende_wmt32k' + +# Parameters for train_and_eval_dataset: +# ============================================================================== +train_and_eval_dataset.eval_holdout_size = 0.05 +train_and_eval_dataset.eval_shuffle_files = True + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.wmt_concat_preprocess +wmt_concat_preprocess.max_length = 255 +wmt_concat_preprocess.max_eval_length = 255 + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.d_model = 512 +TransformerLM.d_ff = 2048 +TransformerLM.dropout = 0.1 +TransformerLM.max_len = 2048 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = 6 +TransformerLM.vocab_size = 33300 + +# Parameters for OnlineTuneEnv: +# ============================================================================== +OnlineTuneEnv.inputs = @trax.inputs.inputs +OnlineTuneEnv.model = @trax.models.TransformerLM +OnlineTuneEnv.optimizer = @trax.optimizers.Adafactor +OnlineTuneEnv.train_steps = 500 +OnlineTuneEnv.eval_steps = 1 +OnlineTuneEnv.env_steps = 100 +OnlineTuneEnv.has_weights = True +OnlineTuneEnv.control_configs = ( + ("learning_rate", 1e-3, (1e-9, 1e-2), False), + ("weight_decay_rate", 1e-5, (1e-9, 1e-3), False), + + ("dropout_embedding", 0.1, (0.0, 0.9), True), + + ("dropout_attention_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_initial", 0.1, (0.0, 0.9), True), + + ("dropout_attention_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_middle", 0.1, (0.0, 0.9), True), + + ("dropout_attention_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_final", 0.1, (0.0, 0.9), True), +) +OnlineTuneEnv.nontrainable_param_map = { + # "dropout_{layer_type}_{block_index}": "dropout_{layer_type}_{block_group}" + + "dropout_attention_0": "dropout_attention_initial", + "dropout_ff_middle_0": "dropout_ff_middle_initial", + "dropout_ff_final_0": "dropout_ff_final_initial", + + "dropout_attention_1": "dropout_attention_middle", + "dropout_ff_middle_1": "dropout_ff_middle_middle", + "dropout_ff_final_1": "dropout_ff_final_middle", + "dropout_attention_2": "dropout_attention_middle", + "dropout_ff_middle_2": "dropout_ff_middle_middle", + "dropout_ff_final_2": "dropout_ff_final_middle", + "dropout_attention_3": "dropout_attention_middle", + "dropout_ff_middle_3": "dropout_ff_middle_middle", + "dropout_ff_final_3": "dropout_ff_final_middle", + "dropout_attention_4": "dropout_attention_middle", + "dropout_ff_middle_4": "dropout_ff_middle_middle", + "dropout_ff_final_4": "dropout_ff_final_middle", + + "dropout_attention_5": "dropout_attention_final", + "dropout_ff_middle_5": "dropout_ff_middle_final", + "dropout_ff_final_5": "dropout_ff_final_final", +} +OnlineTuneEnv.include_controls_in_observation = False +OnlineTuneEnv.observation_range = (0.0, 10.0) +OnlineTuneEnv.action_multipliers = (0.5, 0.8, 0.95, 1.0, 1.05, 1.25, 2.0) +OnlineTuneEnv.mask_id = 0 diff --git a/trax/rl/configs/env_online_tune_transformer_ptb_16gb.gin b/trax/rl/configs/env_online_tune_transformer_ptb_16gb.gin new file mode 100644 index 000000000..4945fbceb --- /dev/null +++ b/trax/rl/configs/env_online_tune_transformer_ptb_16gb.gin @@ -0,0 +1,108 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.models +import trax.optimizers +import trax.rl +import trax.rl.envs + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 64 +batch_fn.eval_batch_size = 512 +batch_fn.max_eval_length = 2048 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_languagemodel_ptb10k' +inputs.input_name = 'targets' + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.lm1b_preprocess +lm1b_preprocess.max_target_length = 512 +lm1b_preprocess.max_eval_target_length = 2048 + +# Parameters for DotProductCausalAttention: +# ============================================================================== +DotProductCausalAttention.dropout = 0.1 + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = @trax.layers.DotProductCausalAttention +TransformerLM.d_model = 512 +TransformerLM.d_ff = 2048 +TransformerLM.dropout = 0.1 +TransformerLM.max_len = 2048 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = 6 +TransformerLM.vocab_size = 10240 + +# Parameters for OnlineTuneEnv: +# ============================================================================== +OnlineTuneEnv.inputs = @trax.inputs.inputs +OnlineTuneEnv.model = @trax.models.TransformerLM +OnlineTuneEnv.optimizer = @trax.optimizers.Adafactor +OnlineTuneEnv.train_steps = 200 +OnlineTuneEnv.eval_steps = 2 +OnlineTuneEnv.env_steps = 100 +OnlineTuneEnv.control_configs = ( + ("learning_rate", 1e-3, (1e-9, 1e-2), False), + ("weight_decay_rate", 1e-5, (1e-9, 1e-3), False), + + ("dropout_embedding", 0.1, (0.0, 0.9), True), + + ("dropout_attention_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_initial", 0.1, (0.0, 0.9), True), + + ("dropout_attention_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_middle", 0.1, (0.0, 0.9), True), + + ("dropout_attention_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_final", 0.1, (0.0, 0.9), True), +) +OnlineTuneEnv.nontrainable_param_map = { + # "dropout_{layer_type}_{block_index}": "dropout_{layer_type}_{block_group}" + + "dropout_attention_0": "dropout_attention_initial", + "dropout_ff_middle_0": "dropout_ff_middle_initial", + "dropout_ff_final_0": "dropout_ff_final_initial", + + "dropout_attention_1": "dropout_attention_middle", + "dropout_ff_middle_1": "dropout_ff_middle_middle", + "dropout_ff_final_1": "dropout_ff_final_middle", + "dropout_attention_2": "dropout_attention_middle", + "dropout_ff_middle_2": "dropout_ff_middle_middle", + "dropout_ff_final_2": "dropout_ff_final_middle", + "dropout_attention_3": "dropout_attention_middle", + "dropout_ff_middle_3": "dropout_ff_middle_middle", + "dropout_ff_final_3": "dropout_ff_final_middle", + "dropout_attention_4": "dropout_attention_middle", + "dropout_ff_middle_4": "dropout_ff_middle_middle", + "dropout_ff_final_4": "dropout_ff_final_middle", + + "dropout_attention_5": "dropout_attention_final", + "dropout_ff_middle_5": "dropout_ff_middle_final", + "dropout_ff_final_5": "dropout_ff_final_final", +} +OnlineTuneEnv.include_controls_in_observation = False +OnlineTuneEnv.observation_range = (0.0, 10.0) +OnlineTuneEnv.action_multipliers = (0.5, 0.8, 0.95, 1.0, 1.05, 1.25, 2.0) +OnlineTuneEnv.mask_id = 0 diff --git a/trax/rl/configs/env_online_tune_transformer_ptb_16gb.ginE b/trax/rl/configs/env_online_tune_transformer_ptb_16gb.ginE new file mode 100644 index 000000000..34e42e9cb --- /dev/null +++ b/trax/rl/configs/env_online_tune_transformer_ptb_16gb.ginE @@ -0,0 +1,94 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.rl +import tensor2tensor.trax.rl.envs + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 64 +batch_fn.eval_batch_size = 512 +batch_fn.max_eval_length = 2048 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 't2t_languagemodel_ptb10k' +inputs.input_name = 'targets' + +# Parameters for preprocess_fun: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun=@trax.inputs.lm1b_preprocess +lm1b_preprocess.max_target_length = 512 +lm1b_preprocess.max_eval_target_length = 2048 + +# Parameters for DotProductCausalAttention: +# ============================================================================== +DotProductCausalAttention.dropout = 0.1 + +# Parameters for TransformerLM: +# ============================================================================== +TransformerLM.attention_type = @trax.layers.DotProductCausalAttention +TransformerLM.d_model = 512 +TransformerLM.d_ff = 2048 +TransformerLM.dropout = 0.1 +TransformerLM.max_len = 2048 +TransformerLM.mode = 'train' +TransformerLM.n_heads = 8 +TransformerLM.n_layers = 6 +TransformerLM.vocab_size = 10240 + +# Parameters for OnlineTuneEnv: +# ============================================================================== +OnlineTuneEnv.inputs = @trax.inputs.inputs +OnlineTuneEnv.model = @trax.models.TransformerLM +OnlineTuneEnv.optimizer = @trax.optimizers.Adafactor +OnlineTuneEnv.train_steps = 200 +OnlineTuneEnv.eval_steps = 2 +OnlineTuneEnv.env_steps = 100 +OnlineTuneEnv.control_configs = ( + ("learning_rate", 1e-3, (1e-9, 1e-2), False), + ("weight_decay_rate", 1e-5, (1e-9, 1e-3), False), + + ("dropout_embedding", 0.1, (0.0, 0.9), True), + + ("dropout_attention_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_initial", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_initial", 0.1, (0.0, 0.9), True), + + ("dropout_attention_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_middle", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_middle", 0.1, (0.0, 0.9), True), + + ("dropout_attention_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_middle_final", 0.1, (0.0, 0.9), True), + ("dropout_ff_final_final", 0.1, (0.0, 0.9), True), +) +OnlineTuneEnv.nontrainable_param_map = { + # "dropout_{layer_type}_{block_index}": "dropout_{layer_type}_{block_group}" + + "dropout_attention_0": "dropout_attention_initial", + "dropout_ff_middle_0": "dropout_ff_middle_initial", + "dropout_ff_final_0": "dropout_ff_final_initial", + + "dropout_attention_1": "dropout_attention_middle", + "dropout_ff_middle_1": "dropout_ff_middle_middle", + "dropout_ff_final_1": "dropout_ff_final_middle", + "dropout_attention_2": "dropout_attention_middle", + "dropout_ff_middle_2": "dropout_ff_middle_middle", + "dropout_ff_final_2": "dropout_ff_final_middle", + "dropout_attention_3": "dropout_attention_middle", + "dropout_ff_middle_3": "dropout_ff_middle_middle", + "dropout_ff_final_3": "dropout_ff_final_middle", + "dropout_attention_4": "dropout_attention_middle", + "dropout_ff_middle_4": "dropout_ff_middle_middle", + "dropout_ff_final_4": "dropout_ff_final_middle", + + "dropout_attention_5": "dropout_attention_final", + "dropout_ff_middle_5": "dropout_ff_middle_final", + "dropout_ff_final_5": "dropout_ff_final_final", +} +OnlineTuneEnv.include_controls_in_observation = False +OnlineTuneEnv.observation_range = (0.0, 10.0) +OnlineTuneEnv.action_multipliers = (0.5, 0.8, 0.95, 1.0, 1.05, 1.25, 2.0) +OnlineTuneEnv.mask_id = 0 diff --git a/trax/rl/configs/env_online_tune_wide_resnet_cifar10_8gb.gin b/trax/rl/configs/env_online_tune_wide_resnet_cifar10_8gb.gin new file mode 100644 index 000000000..a84d370f2 --- /dev/null +++ b/trax/rl/configs/env_online_tune_wide_resnet_cifar10_8gb.gin @@ -0,0 +1,66 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.models +import trax.optimizers +import trax.rl +import trax.rl.envs + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 256 +batch_fn.bucket_length = 32 +batch_fn.buckets = None +batch_fn.eval_batch_size = 512 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 'cifar10' + +# Parameters for train_and_eval_dataset: +# ============================================================================== +train_and_eval_dataset.eval_holdout_size = 0.05 +train_and_eval_dataset.eval_shuffle_files = True + +# Parameters for Momentum: +# ============================================================================== +Momentum.mass = 0.9 + +# Parameters for shuffle_and_batch_data: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun = @trax.inputs.cifar10_augmentation_preprocess + +# Parameters for WideResnet: +# ============================================================================== +WideResnet.widen_factor = 10 +WideResnet.n_blocks = 4 +WideResnet.n_output_classes = 10 + +# Parameters for OnlineTuneEnv: +# ============================================================================== +OnlineTuneEnv.inputs = @trax.inputs.inputs +OnlineTuneEnv.model = @trax.models.WideResnet +OnlineTuneEnv.optimizer = @trax.optimizers.Momentum +OnlineTuneEnv.control_configs = ( + ("learning_rate", 0.1, (1e-9, 10.0), False), + ("weight_decay_rate", 1e-5, (1e-9, 0.1), False), + ("mass", 0.9, (0.0, 0.99), True), +) +OnlineTuneEnv.include_controls_in_observation = False +OnlineTuneEnv.action_multipliers = (0.5, 0.8, 0.95, 1.0, 1.05, 1.25, 2.0) +OnlineTuneEnv.train_steps = 100 +OnlineTuneEnv.eval_steps = 10 +OnlineTuneEnv.env_steps = 100 diff --git a/trax/rl/configs/env_online_tune_wide_resnet_cifar10_8gb.ginE b/trax/rl/configs/env_online_tune_wide_resnet_cifar10_8gb.ginE new file mode 100644 index 000000000..9b8a8e6d5 --- /dev/null +++ b/trax/rl/configs/env_online_tune_wide_resnet_cifar10_8gb.ginE @@ -0,0 +1,52 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.rl +import tensor2tensor.trax.rl.envs + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size_per_device = 256 +batch_fn.bucket_length = 32 +batch_fn.buckets = None +batch_fn.eval_batch_size = 512 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 'cifar10' + +# Parameters for train_and_eval_dataset: +# ============================================================================== +train_and_eval_dataset.eval_holdout_size = 0.05 +train_and_eval_dataset.eval_shuffle_files = True + +# Parameters for Momentum: +# ============================================================================== +Momentum.mass = 0.9 + +# Parameters for shuffle_and_batch_data: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun = @trax.inputs.cifar10_augmentation_preprocess + +# Parameters for WideResnet: +# ============================================================================== +WideResnet.widen_factor = 10 +WideResnet.n_blocks = 4 +WideResnet.n_output_classes = 10 + +# Parameters for OnlineTuneEnv: +# ============================================================================== +OnlineTuneEnv.inputs = @trax.inputs.inputs +OnlineTuneEnv.model = @trax.models.WideResnet +OnlineTuneEnv.optimizer = @trax.optimizers.Momentum +OnlineTuneEnv.control_configs = ( + ("learning_rate", 0.1, (1e-9, 10.0), False), + ("weight_decay_rate", 1e-5, (1e-9, 0.1), False), + ("mass", 0.9, (0.0, 0.99), True), +) +OnlineTuneEnv.include_controls_in_observation = False +OnlineTuneEnv.action_multipliers = (0.5, 0.8, 0.95, 1.0, 1.05, 1.25, 2.0) +OnlineTuneEnv.train_steps = 100 +OnlineTuneEnv.eval_steps = 10 +OnlineTuneEnv.env_steps = 100 diff --git a/trax/rl/configs/ppo_online_tune.gin b/trax/rl/configs/ppo_online_tune.gin new file mode 100644 index 000000000..a1b13ee03 --- /dev/null +++ b/trax/rl/configs/ppo_online_tune.gin @@ -0,0 +1,51 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.models +import trax.optimizers +import trax.rl.trainers + +# Parameters for Adam: +# ============================================================================== +Adam.learning_rate = 1e-3 +Adam.b1 = 0.9 +Adam.b2 = 0.999 +Adam.weight_decay_rate = 0.0 + +# Parameters for TransformerDecoder: +# ============================================================================== +TransformerDecoder.d_model = 64 +TransformerDecoder.d_ff = 128 +TransformerDecoder.dropout = 0.0 +TransformerDecoder.n_heads = 2 +TransformerDecoder.n_layers = 1 + +# Parameters for PPO: +# ============================================================================== +PPO.n_optimizer_steps = 10 +PPO.target_kl = 0.1 +PPO.boundary = 128 +PPO.max_timestep = 128 +PPO.max_timestep_eval = 128 +PPO.random_seed = None +PPO.gamma = 1.0 +PPO.lambda_ = 0.95 +PPO.c1 = 1.0 +PPO.c2 = 0.1 +PPO.done_frac_for_policy_save = 0 +PPO.len_history_for_policy = None +PPO.separate_eval = False +PPO.save_every_n = 1 +PPO.policy_and_value_model = @trax.models.TransformerDecoder +PPO.policy_and_value_optimizer = @trax.optimizers.Adam diff --git a/trax/rl/configs/ppo_online_tune.ginE b/trax/rl/configs/ppo_online_tune.ginE new file mode 100644 index 000000000..bfaa6e4fe --- /dev/null +++ b/trax/rl/configs/ppo_online_tune.ginE @@ -0,0 +1,37 @@ +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.rl.trainers + +# Parameters for Adam: +# ============================================================================== +Adam.learning_rate = 1e-3 +Adam.b1 = 0.9 +Adam.b2 = 0.999 +Adam.weight_decay_rate = 0.0 + +# Parameters for TransformerDecoder: +# ============================================================================== +TransformerDecoder.d_model = 64 +TransformerDecoder.d_ff = 128 +TransformerDecoder.dropout = 0.0 +TransformerDecoder.n_heads = 2 +TransformerDecoder.n_layers = 1 + +# Parameters for PPO: +# ============================================================================== +PPO.n_optimizer_steps = 10 +PPO.target_kl = 0.1 +PPO.boundary = 128 +PPO.max_timestep = 128 +PPO.max_timestep_eval = 128 +PPO.random_seed = None +PPO.gamma = 1.0 +PPO.lambda_ = 0.95 +PPO.c1 = 1.0 +PPO.c2 = 0.1 +PPO.done_frac_for_policy_save = 0 +PPO.len_history_for_policy = None +PPO.separate_eval = False +PPO.save_every_n = 1 +PPO.policy_and_value_model = @trax.models.TransformerDecoder +PPO.policy_and_value_optimizer = @trax.optimizers.Adam diff --git a/trax/rl/configs/ppo_online_tune_wide_resnet_cifar10.gin b/trax/rl/configs/ppo_online_tune_wide_resnet_cifar10.gin new file mode 100644 index 000000000..e7d5c07e4 --- /dev/null +++ b/trax/rl/configs/ppo_online_tune_wide_resnet_cifar10.gin @@ -0,0 +1,93 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.inputs +import trax.models +import trax.optimizers +import trax.rl.envs +import trax.rl.trainers + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size = 32 +batch_fn.bucket_length = 32 +batch_fn.buckets = None +batch_fn.eval_batch_size = 32 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 'cifar10' + +# Parameters for Momentum: +# ============================================================================== +Momentum.mass = 0.9 + +# Parameters for shuffle_and_batch_data: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun = @trax.inputs.cifar10_no_augmentation_preprocess + +# Parameters for Adam: +# ============================================================================== +Adam.learning_rate = 1e-3 +Adam.b1 = 0.9 +Adam.b2 = 0.999 +Adam.weight_decay_rate = 0.0 + +# Parameters for TransformerDecoder: +# ============================================================================== +TransformerDecoder.d_model = 64 +TransformerDecoder.d_ff = 128 +TransformerDecoder.dropout = 0.0 +TransformerDecoder.n_heads = 2 +TransformerDecoder.n_layers = 1 + +# Parameters for WideResnet: +# ============================================================================== +WideResnet.widen_factor = 10 +WideResnet.n_blocks = 3 +WideResnet.n_output_classes = 10 + +# Parameters for OnlineTuneEnv: +# ============================================================================== +OnlineTuneEnv.inputs = @trax.inputs.inputs +OnlineTuneEnv.model = @trax.models.WideResnet +OnlineTuneEnv.optimizer = @trax.optimizers.Momentum +OnlineTuneEnv.start_lr = 0.01 +OnlineTuneEnv.train_steps = 500 +OnlineTuneEnv.eval_steps = 50 +OnlineTuneEnv.env_steps = 100 + +# Parameters for PPO: +# ============================================================================== +PPO.n_optimizer_steps = 10 +PPO.target_kl = 0.1 +PPO.boundary = 128 +PPO.max_timestep = 128 +PPO.max_timestep_eval = 128 +PPO.random_seed = None +PPO.gamma = 0.99 +PPO.lambda_ = 0.95 +PPO.c1 = 1.0 +PPO.c2 = 0.01 +PPO.done_frac_for_policy_save = 0 +PPO.len_history_for_policy = None +PPO.separate_eval = False +PPO.policy_and_value_model = @trax.models.TransformerDecoder +PPO.policy_and_value_optimizer = @trax.optimizers.Adam + +# Parameters for train_rl: +# ============================================================================== +train_rl.env_name = "OnlineTuneEnv-v0" +train_rl.n_epochs = 1000 diff --git a/trax/rl/configs/ppo_online_tune_wide_resnet_cifar10.ginE b/trax/rl/configs/ppo_online_tune_wide_resnet_cifar10.ginE new file mode 100644 index 000000000..040da7703 --- /dev/null +++ b/trax/rl/configs/ppo_online_tune_wide_resnet_cifar10.ginE @@ -0,0 +1,79 @@ +import tensor2tensor.trax.inputs +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.rl.envs +import tensor2tensor.trax.rl.trainers + +# Parameters for batch_fn: +# ============================================================================== +batch_fn.batch_size = 32 +batch_fn.bucket_length = 32 +batch_fn.buckets = None +batch_fn.eval_batch_size = 32 + +# Parameters for inputs: +# ============================================================================== +inputs.data_dir = None +inputs.dataset_name = 'cifar10' + +# Parameters for Momentum: +# ============================================================================== +Momentum.mass = 0.9 + +# Parameters for shuffle_and_batch_data: +# ============================================================================== +shuffle_and_batch_data.preprocess_fun = @trax.inputs.cifar10_no_augmentation_preprocess + +# Parameters for Adam: +# ============================================================================== +Adam.learning_rate = 1e-3 +Adam.b1 = 0.9 +Adam.b2 = 0.999 +Adam.weight_decay_rate = 0.0 + +# Parameters for TransformerDecoder: +# ============================================================================== +TransformerDecoder.d_model = 64 +TransformerDecoder.d_ff = 128 +TransformerDecoder.dropout = 0.0 +TransformerDecoder.n_heads = 2 +TransformerDecoder.n_layers = 1 + +# Parameters for WideResnet: +# ============================================================================== +WideResnet.widen_factor = 10 +WideResnet.n_blocks = 3 +WideResnet.n_output_classes = 10 + +# Parameters for OnlineTuneEnv: +# ============================================================================== +OnlineTuneEnv.inputs = @trax.inputs.inputs +OnlineTuneEnv.model = @trax.models.WideResnet +OnlineTuneEnv.optimizer = @trax.optimizers.Momentum +OnlineTuneEnv.start_lr = 0.01 +OnlineTuneEnv.train_steps = 500 +OnlineTuneEnv.eval_steps = 50 +OnlineTuneEnv.env_steps = 100 + +# Parameters for PPO: +# ============================================================================== +PPO.n_optimizer_steps = 10 +PPO.target_kl = 0.1 +PPO.boundary = 128 +PPO.max_timestep = 128 +PPO.max_timestep_eval = 128 +PPO.random_seed = None +PPO.gamma = 0.99 +PPO.lambda_ = 0.95 +PPO.c1 = 1.0 +PPO.c2 = 0.01 +PPO.done_frac_for_policy_save = 0 +PPO.len_history_for_policy = None +PPO.separate_eval = False +PPO.policy_and_value_model = @trax.models.TransformerDecoder +PPO.policy_and_value_optimizer = @trax.optimizers.Adam + +# Parameters for train_rl: +# ============================================================================== +train_rl.env_name = "OnlineTuneEnv-v0" +train_rl.n_epochs = 1000 diff --git a/trax/rl/configs/simple_online_tune.gin b/trax/rl/configs/simple_online_tune.gin new file mode 100644 index 000000000..63ba1874d --- /dev/null +++ b/trax/rl/configs/simple_online_tune.gin @@ -0,0 +1,109 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.models +import trax.optimizers +import trax.trainer_lib +import trax.rl +import trax.rl.space_serializer +import trax.rl.trainers + +# Parameters for BoxSpaceSerializer: +# ============================================================================== +BoxSpaceSerializer.precision = 2 + +# Parameters for MultifactorSchedule: +# ============================================================================== +world_model/MultifactorSchedule.constant = 1.0 +world_model/MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +world_model/MultifactorSchedule.warmup_steps = 10000 + +# Parameters for Adam: +# ============================================================================== +Adam.learning_rate = 1e-3 +Adam.b1 = 0.9 +Adam.b2 = 0.999 +Adam.weight_decay_rate = 0.0 + +# Parameters for TransformerDecoder: +# ============================================================================== +TransformerDecoder.d_model = 64 +TransformerDecoder.d_ff = 128 +TransformerDecoder.n_layers = 2 +TransformerDecoder.n_heads = 2 +TransformerDecoder.dropout = 0.0 + +# Parameters for PPO: +# ============================================================================== +PPO.n_optimizer_steps = 10 +PPO.optimizer_batch_size = 128 +PPO.target_kl = 0.1 +PPO.boundary = 100 +PPO.max_timestep = 100 +PPO.max_timestep_eval = 100 +PPO.random_seed = None +PPO.gamma = 1.0 +PPO.lambda_ = 0.95 +PPO.c1 = 1.0 +PPO.c2 = 0.1 +PPO.done_frac_for_policy_save = 0 +PPO.len_history_for_policy = None +PPO.separate_eval = False +PPO.save_every_n = 1 +PPO.policy_and_value_model = @trax.models.TransformerDecoder +PPO.policy_and_value_optimizer = @trax.optimizers.Adam +PPO.trajectory_dump_min_count_per_shard = 8 +PPO.print_every_optimizer_steps = 1 + +## Parameters for TimeBinCausalAttention: +## ============================================================================== +world_model/TimeBinCausalAttention.dropout = 0.1 +world_model/TimeBinCausalAttention.bin_length = 512 + +# Parameters for SerializedSequenceSimulatedEnvProblem: +# ============================================================================== +SerializedSequenceSimulatedEnvProblem.model = @world_model/trax.models.TransformerLM +SerializedSequenceSimulatedEnvProblem.reward_fn = @trax.rl.onlinetune_reward_fn +SerializedSequenceSimulatedEnvProblem.done_fn = @trax.rl.onlinetune_done_fn +SerializedSequenceSimulatedEnvProblem.vocab_size = 128 +SerializedSequenceSimulatedEnvProblem.max_trajectory_length = 101 +SerializedSequenceSimulatedEnvProblem.significance_decay = 0.8 + +# Parameters for SimPLe: +# ============================================================================== +SimPLe.policy_trainer_class = @trax.rl.trainers.PPO +SimPLe.n_real_epochs = 1 +SimPLe.n_model_initial_train_steps = 50000 +SimPLe.n_model_train_steps_per_epoch = 10000 +SimPLe.model_train_batch_size = 64 +SimPLe.simulated_env_problem_class = @trax.rl.SerializedSequenceSimulatedEnvProblem +SimPLe.simulated_batch_size = 128 +SimPLe.n_simulated_epochs = 50 +SimPLe.initial_trajectory_mix_prob = 0.9 +SimPLe.init_policy_from_world_model = False + +# Parameters for TransformerLM: +# ============================================================================== +world_model/TransformerLM.attention_type = @world_model/trax.layers.TimeBinCausalAttention +world_model/TransformerLM.d_model = 256 +world_model/TransformerLM.d_ff = 512 +world_model/TransformerLM.n_layers = 3 +world_model/TransformerLM.n_heads = 4 +world_model/TransformerLM.dropout = 0.1 +world_model/TransformerLM.max_len = 2048 + +# Parameters for train: +# ============================================================================== +world_model/train.eval_frequency = 1000 +world_model/train.optimizer = @trax.optimizers.Adafactor diff --git a/trax/rl/configs/simple_online_tune.ginE b/trax/rl/configs/simple_online_tune.ginE new file mode 100644 index 000000000..ac67788f0 --- /dev/null +++ b/trax/rl/configs/simple_online_tune.ginE @@ -0,0 +1,95 @@ +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.trax +import tensor2tensor.trax.rl +import tensor2tensor.trax.rl.space_serializer +import tensor2tensor.trax.rl.trainers + +# Parameters for BoxSpaceSerializer: +# ============================================================================== +BoxSpaceSerializer.precision = 2 + +# Parameters for MultifactorSchedule: +# ============================================================================== +world_model/MultifactorSchedule.constant = 1.0 +world_model/MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +world_model/MultifactorSchedule.warmup_steps = 10000 + +# Parameters for Adam: +# ============================================================================== +Adam.learning_rate = 1e-3 +Adam.b1 = 0.9 +Adam.b2 = 0.999 +Adam.weight_decay_rate = 0.0 + +# Parameters for TransformerDecoder: +# ============================================================================== +TransformerDecoder.d_model = 64 +TransformerDecoder.d_ff = 128 +TransformerDecoder.n_layers = 2 +TransformerDecoder.n_heads = 2 +TransformerDecoder.dropout = 0.0 + +# Parameters for PPO: +# ============================================================================== +PPO.n_optimizer_steps = 10 +PPO.optimizer_batch_size = 128 +PPO.target_kl = 0.1 +PPO.boundary = 100 +PPO.max_timestep = 100 +PPO.max_timestep_eval = 100 +PPO.random_seed = None +PPO.gamma = 1.0 +PPO.lambda_ = 0.95 +PPO.c1 = 1.0 +PPO.c2 = 0.1 +PPO.done_frac_for_policy_save = 0 +PPO.len_history_for_policy = None +PPO.separate_eval = False +PPO.save_every_n = 1 +PPO.policy_and_value_model = @trax.models.TransformerDecoder +PPO.policy_and_value_optimizer = @trax.optimizers.Adam +PPO.trajectory_dump_min_count_per_shard = 8 +PPO.print_every_optimizer_steps = 1 + +## Parameters for TimeBinCausalAttention: +## ============================================================================== +world_model/TimeBinCausalAttention.dropout = 0.1 +world_model/TimeBinCausalAttention.bin_length = 512 + +# Parameters for SerializedSequenceSimulatedEnvProblem: +# ============================================================================== +SerializedSequenceSimulatedEnvProblem.model = @world_model/trax.models.TransformerLM +SerializedSequenceSimulatedEnvProblem.reward_fn = @trax.rl.onlinetune_reward_fn +SerializedSequenceSimulatedEnvProblem.done_fn = @trax.rl.onlinetune_done_fn +SerializedSequenceSimulatedEnvProblem.vocab_size = 128 +SerializedSequenceSimulatedEnvProblem.max_trajectory_length = 101 +SerializedSequenceSimulatedEnvProblem.significance_decay = 0.8 + +# Parameters for SimPLe: +# ============================================================================== +SimPLe.policy_trainer_class = @trax.rl.trainers.PPO +SimPLe.n_real_epochs = 1 +SimPLe.n_model_initial_train_steps = 50000 +SimPLe.n_model_train_steps_per_epoch = 10000 +SimPLe.model_train_batch_size = 64 +SimPLe.simulated_env_problem_class = @trax.rl.SerializedSequenceSimulatedEnvProblem +SimPLe.simulated_batch_size = 128 +SimPLe.n_simulated_epochs = 50 +SimPLe.initial_trajectory_mix_prob = 0.9 +SimPLe.init_policy_from_world_model = False + +# Parameters for TransformerLM: +# ============================================================================== +world_model/TransformerLM.attention_type = @world_model/trax.layers.TimeBinCausalAttention +world_model/TransformerLM.d_model = 256 +world_model/TransformerLM.d_ff = 512 +world_model/TransformerLM.n_layers = 3 +world_model/TransformerLM.n_heads = 4 +world_model/TransformerLM.dropout = 0.1 +world_model/TransformerLM.max_len = 2048 + +# Parameters for train: +# ============================================================================== +world_model/train.eval_frequency = 1000 +world_model/train.optimizer = @trax.optimizers.Adafactor diff --git a/trax/rl/configs/simple_online_tune_serialized.gin b/trax/rl/configs/simple_online_tune_serialized.gin new file mode 100644 index 000000000..e72915d3b --- /dev/null +++ b/trax/rl/configs/simple_online_tune_serialized.gin @@ -0,0 +1,114 @@ +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import trax.models +import trax.optimizers +import trax.trainer_lib +import trax.rl +import trax.rl.space_serializer +import trax.rl.trainers + +# Parameters for BoxSpaceSerializer: +# ============================================================================== +BoxSpaceSerializer.precision = 2 + +# Parameters for MultifactorSchedule: +# ============================================================================== +world_model/MultifactorSchedule.constant = 1.0 +world_model/MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +world_model/MultifactorSchedule.warmup_steps = 10000 + +# Parameters for Adam: +# ============================================================================== +Adam.learning_rate = 1e-3 +Adam.b1 = 0.9 +Adam.b2 = 0.999 +Adam.weight_decay_rate = 0.0 + +# Parameters for TransformerDecoder: +# ============================================================================== +TransformerDecoder.attention_type = @policy/trax.layers.TimeBinCausalAttention +TransformerDecoder.d_model = 256 +TransformerDecoder.d_ff = 512 +TransformerDecoder.n_layers = 3 +TransformerDecoder.n_heads = 4 +TransformerDecoder.dropout = 0.0 + +# Parameters for PPO: +# ============================================================================== +PPO.n_optimizer_steps = 20 +PPO.optimizer_batch_size = 64 +PPO.target_kl = 0.1 +PPO.boundary = 100 +PPO.max_timestep = 100 +PPO.max_timestep_eval = 100 +PPO.random_seed = None +PPO.gamma = 1.0 +PPO.lambda_ = 0.95 +PPO.c1 = 1.0 +PPO.c2 = 0.1 +PPO.done_frac_for_policy_save = 0 +PPO.len_history_for_policy = None +PPO.separate_eval = False +PPO.save_every_n = 1 +PPO.policy_and_value_model = @trax.models.TransformerDecoder +PPO.policy_and_value_optimizer = @trax.optimizers.Adam +PPO.policy_and_value_vocab_size = 128 +PPO.trajectory_dump_min_count_per_shard = 8 +PPO.print_every_optimizer_steps = 1 + +## Parameters for TimeBinCausalAttention: +## ============================================================================== +world_model/TimeBinCausalAttention.dropout = 0.1 +world_model/TimeBinCausalAttention.bin_length = 512 + +policy/TimeBinCausalAttention.dropout = 0.0 +policy/TimeBinCausalAttention.bin_length = 512 + +# Parameters for SerializedSequenceSimulatedEnvProblem: +# ============================================================================== +SerializedSequenceSimulatedEnvProblem.model = @world_model/trax.models.TransformerLM +SerializedSequenceSimulatedEnvProblem.reward_fn = @trax.rl.onlinetune_reward_fn +SerializedSequenceSimulatedEnvProblem.done_fn = @trax.rl.onlinetune_done_fn +SerializedSequenceSimulatedEnvProblem.vocab_size = 128 +SerializedSequenceSimulatedEnvProblem.max_trajectory_length = 101 +SerializedSequenceSimulatedEnvProblem.significance_decay = 0.8 + +# Parameters for SimPLe: +# ============================================================================== +SimPLe.policy_trainer_class = @trax.rl.trainers.PPO +SimPLe.n_real_epochs = 1 +SimPLe.n_model_initial_train_steps = 50000 +SimPLe.n_model_train_steps_per_epoch = 10000 +SimPLe.model_train_batch_size = 64 +SimPLe.simulated_env_problem_class = @trax.rl.SerializedSequenceSimulatedEnvProblem +SimPLe.simulated_batch_size = 128 +SimPLe.n_simulated_epochs = 50 +SimPLe.initial_trajectory_mix_prob = 0.9 +SimPLe.init_policy_from_world_model = True + +# Parameters for TransformerLM: +# ============================================================================== +world_model/TransformerLM.attention_type = @world_model/trax.layers.TimeBinCausalAttention +world_model/TransformerLM.d_model = 256 +world_model/TransformerLM.d_ff = 512 +world_model/TransformerLM.n_layers = 3 +world_model/TransformerLM.n_heads = 4 +world_model/TransformerLM.dropout = 0.1 +world_model/TransformerLM.max_len = 2048 + +# Parameters for train: +# ============================================================================== +world_model/train.eval_frequency = 1000 +world_model/train.optimizer = @trax.optimizers.Adafactor diff --git a/trax/rl/configs/simple_online_tune_serialized.ginE b/trax/rl/configs/simple_online_tune_serialized.ginE new file mode 100644 index 000000000..181efb64c --- /dev/null +++ b/trax/rl/configs/simple_online_tune_serialized.ginE @@ -0,0 +1,100 @@ +import tensor2tensor.trax.models +import tensor2tensor.trax.optimizers +import tensor2tensor.trax.trax +import tensor2tensor.trax.rl +import tensor2tensor.trax.rl.space_serializer +import tensor2tensor.trax.rl.trainers + +# Parameters for BoxSpaceSerializer: +# ============================================================================== +BoxSpaceSerializer.precision = 2 + +# Parameters for MultifactorSchedule: +# ============================================================================== +world_model/MultifactorSchedule.constant = 1.0 +world_model/MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' +world_model/MultifactorSchedule.warmup_steps = 10000 + +# Parameters for Adam: +# ============================================================================== +Adam.learning_rate = 1e-3 +Adam.b1 = 0.9 +Adam.b2 = 0.999 +Adam.weight_decay_rate = 0.0 + +# Parameters for TransformerDecoder: +# ============================================================================== +TransformerDecoder.attention_type = @policy/trax.layers.TimeBinCausalAttention +TransformerDecoder.d_model = 256 +TransformerDecoder.d_ff = 512 +TransformerDecoder.n_layers = 3 +TransformerDecoder.n_heads = 4 +TransformerDecoder.dropout = 0.0 + +# Parameters for PPO: +# ============================================================================== +PPO.n_optimizer_steps = 20 +PPO.optimizer_batch_size = 64 +PPO.target_kl = 0.1 +PPO.boundary = 100 +PPO.max_timestep = 100 +PPO.max_timestep_eval = 100 +PPO.random_seed = None +PPO.gamma = 1.0 +PPO.lambda_ = 0.95 +PPO.c1 = 1.0 +PPO.c2 = 0.1 +PPO.done_frac_for_policy_save = 0 +PPO.len_history_for_policy = None +PPO.separate_eval = False +PPO.save_every_n = 1 +PPO.policy_and_value_model = @trax.models.TransformerDecoder +PPO.policy_and_value_optimizer = @trax.optimizers.Adam +PPO.policy_and_value_vocab_size = 128 +PPO.trajectory_dump_min_count_per_shard = 8 +PPO.print_every_optimizer_steps = 1 + +## Parameters for TimeBinCausalAttention: +## ============================================================================== +world_model/TimeBinCausalAttention.dropout = 0.1 +world_model/TimeBinCausalAttention.bin_length = 512 + +policy/TimeBinCausalAttention.dropout = 0.0 +policy/TimeBinCausalAttention.bin_length = 512 + +# Parameters for SerializedSequenceSimulatedEnvProblem: +# ============================================================================== +SerializedSequenceSimulatedEnvProblem.model = @world_model/trax.models.TransformerLM +SerializedSequenceSimulatedEnvProblem.reward_fn = @trax.rl.onlinetune_reward_fn +SerializedSequenceSimulatedEnvProblem.done_fn = @trax.rl.onlinetune_done_fn +SerializedSequenceSimulatedEnvProblem.vocab_size = 128 +SerializedSequenceSimulatedEnvProblem.max_trajectory_length = 101 +SerializedSequenceSimulatedEnvProblem.significance_decay = 0.8 + +# Parameters for SimPLe: +# ============================================================================== +SimPLe.policy_trainer_class = @trax.rl.trainers.PPO +SimPLe.n_real_epochs = 1 +SimPLe.n_model_initial_train_steps = 50000 +SimPLe.n_model_train_steps_per_epoch = 10000 +SimPLe.model_train_batch_size = 64 +SimPLe.simulated_env_problem_class = @trax.rl.SerializedSequenceSimulatedEnvProblem +SimPLe.simulated_batch_size = 128 +SimPLe.n_simulated_epochs = 50 +SimPLe.initial_trajectory_mix_prob = 0.9 +SimPLe.init_policy_from_world_model = True + +# Parameters for TransformerLM: +# ============================================================================== +world_model/TransformerLM.attention_type = @world_model/trax.layers.TimeBinCausalAttention +world_model/TransformerLM.d_model = 256 +world_model/TransformerLM.d_ff = 512 +world_model/TransformerLM.n_layers = 3 +world_model/TransformerLM.n_heads = 4 +world_model/TransformerLM.dropout = 0.1 +world_model/TransformerLM.max_len = 2048 + +# Parameters for train: +# ============================================================================== +world_model/train.eval_frequency = 1000 +world_model/train.optimizer = @trax.optimizers.Adafactor diff --git a/trax/rl/envs/__init__.py b/trax/rl/envs/__init__.py new file mode 100644 index 000000000..da8950638 --- /dev/null +++ b/trax/rl/envs/__init__.py @@ -0,0 +1,34 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Environments defined in RL.""" + +import gin +from gym.envs.registration import register + +from trax.rl.envs import online_tune_env + + +# Ginify and register in gym. +def configure_and_register_env(env_class): + register( + id="{}-v0".format(env_class.__name__), + entry_point="trax.rl.envs:{}".format(env_class.__name__), + ) + return gin.external_configurable(env_class, module="trax.rl.envs") + + +# pylint: disable=invalid-name +OnlineTuneEnv = configure_and_register_env(online_tune_env.OnlineTuneEnv) diff --git a/trax/rl/envs/async_trajectory_collector.py b/trax/rl/envs/async_trajectory_collector.py new file mode 100644 index 000000000..521b2a1f5 --- /dev/null +++ b/trax/rl/envs/async_trajectory_collector.py @@ -0,0 +1,197 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A trajectory collector that polls on policy files and keeps collecting trajectories.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import multiprocessing +import os +from absl import app +from absl import flags +from absl import logging +import gin +import jax +from jax.config import config +from tensor2tensor.envs import env_problem_utils +from tensor2tensor.rl.google import atari_utils # GOOGLE-INTERNAL: +import tensorflow as tf +from trax import rl # pylint: disable=unused-import +from trax.rl import envs as rl_envs # pylint: disable=unused-import +from trax.rl.envs import async_trajectory_collector_lib as async_lib + +FLAGS = flags.FLAGS + +flags.DEFINE_multi_string("config_file", None, + "Configuration file with parameters (.gin).") +flags.DEFINE_multi_string("config", None, + "Configuration parameters (gin string).") +flags.DEFINE_bool("use_tpu", False, "Whether we're running on TPU.") +flags.DEFINE_bool("xm", False, "Copy atari roms?") + +flags.DEFINE_bool( + "try_abort", True, + "Should we try to abort a trajectory collection if a newer " + "policy is available.") + +flags.DEFINE_string("output_dir", "", "Output dir.") +flags.DEFINE_string("envs_output_dir", "", "Output dir for the envs.") + +flags.DEFINE_boolean( + "jax_debug_nans", False, + "Setting to true will help to debug nans and disable jit.") +flags.DEFINE_boolean("disable_jit", False, "Setting to true will disable jit.") + +flags.DEFINE_boolean("parallelize_envs", False, + "If true, sets parallelism to number of cpu cores.") + +flags.DEFINE_integer("replica", 0, "Basically to append to trajectory name.") +flags.DEFINE_bool("enable_eager_execution", False, "") + +flags.DEFINE_integer( + "max_trajectories_to_collect", -1, + "-1 for infinite, otherwise whatever number was specified.") + + +# TODO(afrozm): This code snippet is strewn across many places, unify it. +def initialize_gin(): + gin_configs = FLAGS.config or [] + gin.parse_config_files_and_bindings(FLAGS.config_file, gin_configs) + + +def get_output_dir(): + """Return output_dir.""" + output_dir = FLAGS.output_dir + return output_dir + + +def update_jax_config(): + """Update JAX config based on flags.""" + + if FLAGS.jax_debug_nans: + config.update("jax_debug_nans", True) + + if FLAGS.use_tpu: + config.update("jax_platform_name", "tpu") + else: + config.update("jax_platform_name", "gpu") + + +@gin.configurable(blacklist=[ + "output_dir", +]) +def create_envs_and_collect_trajectories( + output_dir, + env_name="OnlineTuneEnv-v0", + max_timestep=None, + clip_rewards=False, + rendered_env=False, + resize_dims=(105, 80), +): + """Creates the envs and continuously collects trajectories.""" + + + train_batch_size = 1 + eval_batch_size = 1 + + # TODO(pkozakowski): Find a better way to determine this. + train_env_kwargs = {} + eval_env_kwargs = {} + if "OnlineTuneEnv" in env_name: + envs_output_dir = FLAGS.envs_output_dir or os.path.join(output_dir, "envs") + train_env_output_dir = os.path.join(envs_output_dir, "train") + eval_env_output_dir = os.path.join(envs_output_dir, "eval") + train_env_kwargs = {"output_dir": train_env_output_dir} + eval_env_kwargs = {"output_dir": eval_env_output_dir} + + if "ClientEnv" in env_name: + train_env_kwargs["per_env_kwargs"] = [{ + "remote_env_address": os.path.join(FLAGS.train_server_bns, str(replica)) + } for replica in range(train_batch_size)] + + eval_env_kwargs["per_env_kwargs"] = [{ + "remote_env_address": os.path.join(FLAGS.eval_server_bns, str(replica)) + } for replica in range(eval_batch_size)] + + parallelism = multiprocessing.cpu_count() if FLAGS.parallelize_envs else 1 + train_parallelism = min(train_batch_size, parallelism) + eval_parallelism = min(eval_batch_size, parallelism) + + train_env = env_problem_utils.make_env( + batch_size=train_batch_size, + env_problem_name=env_name, + resize=rendered_env, + resize_dims=resize_dims, + max_timestep=max_timestep, + clip_rewards=clip_rewards, + parallelism=train_parallelism, + use_tpu=FLAGS.use_tpu, + **train_env_kwargs) + assert train_env + + eval_env = env_problem_utils.make_env( + batch_size=eval_batch_size, + env_problem_name=env_name, + resize=rendered_env, + resize_dims=resize_dims, + max_timestep=max_timestep, + clip_rewards=clip_rewards, + parallelism=eval_parallelism, + use_tpu=FLAGS.use_tpu, + **eval_env_kwargs) + assert eval_env + + def run_collect_loop(): + async_lib.continuously_collect_trajectories( + output_dir, + train_env, + eval_env, + trajectory_dump_dir=None, + env_id=FLAGS.replica, + try_abort=FLAGS.try_abort, + max_trajectories_to_collect=(None + if FLAGS.max_trajectories_to_collect < 0 + else FLAGS.max_trajectories_to_collect)) + + if FLAGS.jax_debug_nans or FLAGS.disable_jit: + with jax.disable_jit(): + run_collect_loop() + else: + run_collect_loop() + + +def main(argv): + del argv + + if FLAGS.enable_eager_execution: + tf.enable_eager_execution() + + logging.info("Initializing Gin.") + initialize_gin() + + logging.info("Update JAX config.") + update_jax_config() + + logging.info("Getting output_dir") + output_dir = get_output_dir() + logging.info("Got output_dir = %s", output_dir) + + logging.info("Starting Trajectory collection.") + create_envs_and_collect_trajectories(output_dir) + + +if __name__ == "__main__": + app.run(main) diff --git a/trax/rl/envs/async_trajectory_collector_lib.py b/trax/rl/envs/async_trajectory_collector_lib.py new file mode 100644 index 000000000..935ab0489 --- /dev/null +++ b/trax/rl/envs/async_trajectory_collector_lib.py @@ -0,0 +1,195 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Uitlity functions for the async trajectory collector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import random +import time + +from absl import logging +from tensor2tensor.envs import trajectory +from tensorflow.io import gfile +from trax.rl import ppo +from trax.rl import trainers as rl_trainers + +LARGE_MAX_TRIES_FOR_POLICY_FILE = 100 + + +# TODO(afrozm): Is there a better way to poll for a file on CNS? +def get_newer_policy_model_file(output_dir, + min_epoch=-1, + sleep_time_secs=0.1, + max_sleep_time_secs=1.0, + max_tries=1, + wait_forever=False,): + """Gets a policy model file subject to availability and wait time.""" + + while max_tries or wait_forever: + max_tries -= 1 + policy_files = ppo.get_policy_model_files(output_dir) + + def do_wait(t): + time.sleep(t) + t *= 2 + return min(t, max_sleep_time_secs) + + # No policy files at all. + if not policy_files: + logging.info("There are no policy files in [%s], waiting for %s secs.", + output_dir, sleep_time_secs) + sleep_time_secs = do_wait(sleep_time_secs) + continue + + # Check if we have a newer epoch. + policy_file = policy_files[0] + epoch = ppo.get_epoch_from_policy_model_file(policy_file) + + # We don't - wait. + if epoch <= min_epoch: + logging.info("epoch [%s] <= min_epoch [%s], waiting for %s secs.", epoch, + min_epoch, sleep_time_secs) + sleep_time_secs = do_wait(sleep_time_secs) + continue + + # We do have a new file, return it. + policy_file = policy_files[0] + epoch = ppo.get_epoch_from_policy_model_file(policy_file) + logging.info("Found epoch [%s] and policy file [%s]", epoch, policy_file) + return policy_file, epoch + + # Exhausted our waiting limit. + return None + + +def dump_trajectory(output_dir, epoch, env_id, temperature, random_string, + trajs): + """Write the trajectory to disk.""" + + assert 1 == len(trajs) + traj = trajs[0] + + trajectory_file_name = trajectory.TRAJECTORY_FILE_FORMAT.format( + epoch=epoch, env_id=env_id, temperature=temperature, r=random_string) + + with gfile.GFile(os.path.join(output_dir, trajectory_file_name), "w") as f: + trajectory.get_pickle_module().dump(traj, f) + + +def continuously_collect_trajectories(output_dir, + train_env, + eval_env, + trajectory_dump_dir=None, + env_id=None, + max_trajectories_to_collect=None, + try_abort=True): + """Instantiates a PPO trainer and collects trajectories.""" + + # Make the PPO trainer. + ppo_trainer = rl_trainers.PPO( + output_dir=output_dir, + train_env=train_env, + eval_env=eval_env, + trajectory_dump_dir=trajectory_dump_dir, + ) + + # TODO(afrozm): Update base_trainer interface to support SimPLe as well. + assert isinstance(ppo_trainer, rl_trainers.PPO) + + assert env_id is not None + + # Get an initial policy and wait a forever to get it if needed. + policy_and_epoch = get_newer_policy_model_file(output_dir, wait_forever=True) + assert policy_and_epoch + policy_file, epoch = policy_and_epoch + logging.info("Read initial policy for epoch [%s] -> [%s]", epoch, policy_file) + + # Returns immediately if there is a newer epoch available. + def is_newer_policy_file_available(epoch_, sleep_time_secs_=0.1): + return get_newer_policy_model_file( + output_dir, min_epoch=epoch_, sleep_time_secs=sleep_time_secs_) + + assert 1 == train_env.batch_size + assert 1 == eval_env.batch_size + + temperature = 1.0 + + trajectories_collected = 0 + + train_env_trajectory_dump_dir = os.path.join(output_dir, "trajectories/train") + eval_env_trajectory_dump_dir = os.path.join(output_dir, "trajectories/eval") + + gfile.makedirs(train_env_trajectory_dump_dir) + gfile.makedirs(eval_env_trajectory_dump_dir) + + while max_trajectories_to_collect is None or trajectories_collected < int( + max_trajectories_to_collect): + logging.info("Collecting a trajectory, trajectories_collected = %s", + trajectories_collected) + + # Abort function -- if something newever is available, then abort the + # current computation and reload. + + # Useful if env.step is long. + def long_abort_fn(): + # We want this to be as quick as possible. + return is_newer_policy_file_available(epoch, 0) is not None + + abort_fn = long_abort_fn if try_abort else None + + # Collect a training trajectory. + trajs, n_done, unused_timing_info, unused_model_state = ( + ppo_trainer.collect_trajectories(train=True, + temperature=temperature, + abort_fn=abort_fn, + raw_trajectory=True)) + + if trajs and n_done > 0: + assert 1 == n_done + trajectories_collected += n_done + + # Write the trajectory down. + logging.info( + "Dumping the collected trajectory, trajectories_collected = %s", + trajectories_collected) + dump_trajectory(train_env_trajectory_dump_dir, epoch, env_id, temperature, + str(random.randint(0, 2**31 - 1)), trajs) + else: + logging.info("Computation was aborted, a new policy is available.") + + # This maybe useless, since `abort_fn` will take care of it. We might want + # to have this here if abort_fn is False always. + # Do we have a newer policy? + policy_file_and_epoch = is_newer_policy_file_available(epoch) + if policy_file_and_epoch is None: + # Continue churning out these policies. + logging.info("We don't have a newer policy, continuing with the old one.") + continue + + # We have a newer policy, read it and update the parameters. + policy_file, epoch = policy_file_and_epoch + logging.info( + "We have a newer policy epoch [%s], file [%s], updating parameters.", + epoch, policy_file) + ppo_trainer.update_optimization_state( + output_dir, policy_and_value_opt_state=None) + logging.info("Parameters of PPOTrainer updated.") + + # Check that the epochs match. + assert epoch == ppo_trainer.epoch diff --git a/trax/rl/envs/async_trajectory_collector_lib_test.py b/trax/rl/envs/async_trajectory_collector_lib_test.py new file mode 100644 index 000000000..92af19c38 --- /dev/null +++ b/trax/rl/envs/async_trajectory_collector_lib_test.py @@ -0,0 +1,64 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from tensorflow import test +from tensorflow.io import gfile +from trax.rl import ppo +from trax.rl.envs import async_trajectory_collector_lib as async_lib + + +class AsyncTrajectoryCollectorLibTest(test.TestCase): + + def test_get_newer_policy_model_file(self): + output_dir = self.get_temp_dir() + + def write_policy_model_file(epoch): + fname = ppo.get_policy_model_file_from_epoch(output_dir, epoch) + with gfile.GFile(fname, "w") as f: + f.write("some data") + return fname + + # No file exists currently. + self.assertIsNone(async_lib.get_newer_policy_model_file(output_dir)) + + # Write a policy model file. + epoch = 0 + policy_model_filename = write_policy_model_file(epoch) + + # See that we get it. + actual_policy_file, actual_epoch = ( + async_lib.get_newer_policy_model_file(output_dir, min_epoch=-1)) + + self.assertEqual(actual_policy_file, policy_model_filename) + self.assertEqual(actual_epoch, epoch) + + # If we now ask for a larger epoch, we don't get it. + self.assertIsNone( + async_lib.get_newer_policy_model_file(output_dir, min_epoch=0)) + + # Write a newer epoch and expect to get that with appropriate min_epoch. + epoch = 1 + policy_model_filename = write_policy_model_file(epoch) + actual_policy_file, actual_epoch = ( + async_lib.get_newer_policy_model_file(output_dir, min_epoch=0)) + self.assertEqual(actual_policy_file, policy_model_filename) + self.assertEqual(actual_epoch, epoch) + + +if __name__ == "__main__": + test.main() diff --git a/trax/rl/envs/fake_env.py b/trax/rl/envs/fake_env.py new file mode 100644 index 000000000..9a99ee36a --- /dev/null +++ b/trax/rl/envs/fake_env.py @@ -0,0 +1,68 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A fake gym environment. + +Can specify either: +1. A done action, i.e. the action on which the environment returns done. +2. A done time-step, i.e. the time step at which the environment returns done. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gym +import numpy as np + + +class FakeEnv(gym.Env): + """A fake env which is either done with a specific action or a time-step.""" + + def __init__(self, + input_shape=(4,), + n_actions=2, + n_controls=1, + done_time_step=None, + done_action=None): + self._input_shape = input_shape + self._done_time_step = done_time_step + self._done_action = done_action + self._t = 0 + if n_controls == 1: + self.action_space = gym.spaces.Discrete(n_actions) + else: + self.action_space = gym.spaces.MultiDiscrete([n_actions] * n_controls) + self.observation_space = gym.spaces.Box( + low=-1.0, high=1.0, shape=input_shape) + + def _get_random_observation(self): + return np.random.random(self._input_shape) + + def reset(self): + self._t = 0 + return self._get_random_observation() + + def step(self, action): + assert self.action_space.contains(action) + done = False + if self._done_action is not None: + done = action == self._done_action + elif self._done_time_step is not None: + done = self._t == self._done_time_step + + reward = -1.0 if not done else 1.0 + self._t += 1 + return self._get_random_observation(), reward, done, {} diff --git a/trax/rl/envs/fake_env_test.py b/trax/rl/envs/fake_env_test.py new file mode 100644 index 000000000..3a16ff7d3 --- /dev/null +++ b/trax/rl/envs/fake_env_test.py @@ -0,0 +1,64 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.rl.fake_env.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from tensorflow import test +from trax.rl.envs import fake_env + + +class FakeEnvTest(test.TestCase): + + def test_done_action(self): + env = fake_env.FakeEnv(input_shape=(2, 3), + n_actions=10, + done_time_step=None, + done_action=9) + env.reset() + + # Actions 0 to 8 + for action in range(9): + _, reward, done, _ = env.step(action) + self.assertFalse(done) + self.assertEqual(-1.0, reward) + + _, reward, done, _ = env.step(9) + self.assertTrue(done) + self.assertEqual(1.0, reward) + + def test_done_time_step(self): + env = fake_env.FakeEnv(input_shape=(2, 3), + n_actions=10, + done_time_step=10, + done_action=None) + env.reset() + + # Take 10 steps. + for _ in range(10): + _, reward, done, _ = env.step(0) + self.assertFalse(done) + self.assertEqual(-1.0, reward) + + # Take final time-step, this is the time-step numbered 10 since time-steps + # are 0 indexed. + _, reward, done, _ = env.step(0) + self.assertTrue(done) + self.assertEqual(1.0, reward) + +if __name__ == '__main__': + test.main() diff --git a/trax/rl/envs/online_tune.py b/trax/rl/envs/online_tune.py new file mode 100644 index 000000000..bcef89fe0 --- /dev/null +++ b/trax/rl/envs/online_tune.py @@ -0,0 +1,57 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for OnlineTuneEnv.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + + +LEARNING_RATE_METRIC = ("train", "training/learning_rate") + + +def historical_metric_values( + history, metric, observation_range=(-np.inf, np.inf)): + """Converts a metric stream from a trax History object into a numpy array.""" + metric_sequence = history.get(*metric) + metric_values = np.array([ + metric_value for (_, metric_value) in metric_sequence + ]) + return np.clip(metric_values, *observation_range) + + +def history_to_observations(history, metrics, observation_range, include_lr): + """Converts a trax History object into a sequence of observations.""" + observation_dimensions = [ + historical_metric_values(history, metric, observation_range) + for metric in metrics + ] + if include_lr: + # Logartihm of the learning rate. + observation_dimensions.append(np.log(historical_metric_values( + history, LEARNING_RATE_METRIC, observation_range + ))) + return np.stack(observation_dimensions, axis=1) + + +def new_learning_rate(action, history, action_multipliers, max_lr): + """Calculates a new learning rate based on an action.""" + learning_rates = historical_metric_values(history, LEARNING_RATE_METRIC) + assert learning_rates.shape[0] > 0, "No last learning rate found in history." + current_lr = learning_rates[-1] + return min(current_lr * action_multipliers[action], max_lr) diff --git a/trax/rl/envs/online_tune_env.py b/trax/rl/envs/online_tune_env.py new file mode 100644 index 000000000..be459d2e1 --- /dev/null +++ b/trax/rl/envs/online_tune_env.py @@ -0,0 +1,232 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An environment for tuning model hyperparameters during training.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import os + +import gym +from tensorflow.io import gfile +from trax import inputs as trax_inputs +from trax import layers +from trax import models as trax_models +from trax import optimizers as trax_opt +from trax import trainer_lib +from trax.rl import online_tune + + +class OnlineTuneEnv(gym.Env): + """An environment for tuning model hyperparameters during training. + + A rollout is one instance of training a specific model on a specific problem. + Observations are the values of some evaluation metric. Actions control + hyperparameter changes during training. Reward is the change of the evaluation + metric. One environment step corresponds to a fixed number of training steps. + + For now we only support tuning the learning rate. + """ + + # Chosen so that the opposite actions cancel each other out, so random walk + # has a median of 1. + DEFAULT_ACTION_MULTIPLIERS = [1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5] + + def __init__(self, + output_dir, + model=trax_models.TransformerLM, + trainer_class=trainer_lib.Trainer, + loss_fn=layers.CrossEntropyLossScalar, + optimizer=trax_opt.Adafactor, + inputs=trax_inputs.inputs, + action_multipliers=None, + observation_metrics=( + ("train", "metrics/accuracy"), + ("train", "metrics/loss"), + ("eval", "metrics/accuracy"), + ("eval", "metrics/loss"), + ), + include_controls_in_observation=False, + reward_metric=("eval", "metrics/accuracy"), + train_steps=100, + eval_steps=10, + env_steps=100, + # This is a tuple instead of a dict because the controls are + # ordered in the action space. + control_configs=( + # (name, start, (low, high), flip) + ("learning_rate", 1e-3, (1e-9, 10.0), False), + ), + nontrainable_param_map=None, + observation_range=(0.0, 10.0), + # Don't save checkpoints by default, as they tend to use a lot of + # space. + should_save_checkpoints=False, + # Same here. + should_write_summaries=False, + has_weights=False, + mask_id=None): + if action_multipliers is None: + action_multipliers = self.DEFAULT_ACTION_MULTIPLIERS + self._model = model + # Initialize Trainer in OnlineTuneEnv lazily to prevent long startup in the + # async setup, where we just use the environments as containers for + # trajectories. + self._trainer_fn = functools.partial( + trainer_class, + model=model, + loss_fn=loss_fn, + optimizer=optimizer, + lr_schedule=(lambda history: lambda step: self._current_controls), + inputs=inputs, + should_save_checkpoints=should_save_checkpoints, + should_write_summaries=should_write_summaries, + nontrainable_param_map=nontrainable_param_map, + has_weights=has_weights, + mask_id=mask_id, + ) + self._trainer = None + self._action_multipliers = action_multipliers + self._observation_metrics = observation_metrics + self._include_controls_in_observation = include_controls_in_observation + self._reward_metric = reward_metric + self._train_steps = train_steps + self._eval_steps = eval_steps + self._env_steps = env_steps + self._control_configs = control_configs + self._observation_range = observation_range + + self._output_dir = output_dir + gfile.makedirs(self._output_dir) + # Actions are indices in self._action_multipliers. + self.action_space = gym.spaces.MultiDiscrete( + [len(self._action_multipliers)] * len(self._control_configs) + ) + # Observation is a vector with the values of the metrics specified in + # observation_metrics plus optionally the current controls. + observation_dim = ( + len(self._observation_metrics) + + int(self._include_controls_in_observation) * len(self._control_configs) + ) + + (obs_low, obs_high) = observation_range + self.observation_space = gym.spaces.Box( + # Observations are clipped to this range. + low=obs_low, high=obs_high, shape=(observation_dim,), + ) + + @property + def _next_trajectory_dir(self): + """Assigns a new output dir for a trajectory under self._output_dir. + + Directory names are consecutive integers starting from zero. New directory + index is assigned as the maximum of past indices plus one. Directories that + are not integers are ignored. + + Returns: + A path of the new directory. + """ + trajectory_dirs = gfile.listdir(self._output_dir) + + def int_or_none(s): + try: + return int(s) + except TypeError: + return None + + past_trajectory_ids = [ + trajectory_id for trajectory_id in map(int_or_none, trajectory_dirs) + if trajectory_id is not None] + next_trajectory_id = max([-1] + past_trajectory_ids) + 1 + + return os.path.join(self._output_dir, str(next_trajectory_id)) + + @property + def _current_reward_metric(self): + metric_values = online_tune.historical_metric_values( + self._trainer.state.history, + self._reward_metric, + ) + assert metric_values.shape[0] > 0, ( + "No values in history for metric {}.".format(self._reward_metric)) + return metric_values[-1] + + @property + def _current_observation(self): + observations = online_tune.history_to_observations( + self._trainer.state.history, + self._observation_metrics, + self._observation_range, + self._control_configs if self._include_controls_in_observation + else None, + ) + assert observations.shape[0] > 0, "No values in history for any metric." + return observations[-1, :] + + @property + def trainer(self): + if self._trainer is None: + raise ValueError("The environment has to be reset first.") + return self._trainer + + def reset(self): + if self._trainer is None: + self._trainer = self._trainer_fn() + self._current_controls = { + name: start_value + for (name, start_value, _, _) in self._control_configs + } + self._step = 0 + self._trainer.reset(output_dir=self._next_trajectory_dir) + self._trainer.evaluate(self._eval_steps) + return self._current_observation + + def step(self, action): + """Step the environment. + + One environment step corresponds to self.train_steps training steps. + + Args: + action: (int) Action to take. An index in self.action_multipliers. + + Returns: + Tuple (observation, reward, done, info). observation is a singleton vector + with the current value of the metric. reward is the difference in the + metric since the last step. done is set after reaching self.env_steps + environment steps. info is an empty dict. + """ + self._current_controls = { + # name: value + control_config[0]: online_tune.update_control( # pylint: disable=g-complex-comprehension + control_config, + control_action, + self._trainer.state.history, + self._action_multipliers, + ) + for (control_action, control_config) in zip( + action, self._control_configs + ) + } + last_reward_metric = self._current_reward_metric + self._trainer.train_epoch(self._train_steps, self._eval_steps) + self._step += 1 + current_reward_metric = self._current_reward_metric + observation = self._current_observation + reward = current_reward_metric - last_reward_metric + done = self._step == self._env_steps + return (observation, reward, done, {}) diff --git a/trax/rl/envs/online_tune_env_test.py b/trax/rl/envs/online_tune_env_test.py new file mode 100644 index 000000000..92c3c4d0c --- /dev/null +++ b/trax/rl/envs/online_tune_env_test.py @@ -0,0 +1,151 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.rl.online_tune_env.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +import numpy as np +from tensorflow import test +from tensorflow.io import gfile +from trax import inputs as trax_inputs +from trax import models +from trax import optimizers +from trax import trainer_lib +from trax.rl import online_tune +from trax.rl.envs import online_tune_env + +HISTORY_MODE = "eval" +METRIC = "metrics/accuracy" + + +class MockTrainer(trainer_lib.Trainer): + + def __init__(self, metrics_to_report, *args, **kwargs): + super(MockTrainer, self).__init__(*args, **kwargs) + self.controls = [] + self.init_metrics_to_report = metrics_to_report + self.metrics_to_report = None + + def reset(self, output_dir): + super(MockTrainer, self).reset(output_dir) + # Copy the sequence to a list so we can modify it later. + self.metrics_to_report = list(self.init_metrics_to_report) + + def train_epoch(self, epoch_steps, eval_steps): + del epoch_steps + self.controls.append(self.nontrainable_params) + self.evaluate(eval_steps) + + def evaluate(self, eval_steps): + del eval_steps + self.state.history.append( + mode=HISTORY_MODE, + metric=METRIC, + step=self.step, + value=self.metrics_to_report.pop(0)) + for (name, value) in self.nontrainable_params.items(): + (mode, metric) = online_tune.control_metric(name) + self.state.history.append( + mode=mode, + metric=metric, + step=self.step, + value=value) + + +class OnlineTuneTest(test.TestCase): + + @staticmethod + def _create_env( + output_dir, metrics_to_report=(0.0,), action_multipliers=(1,) + ): + return online_tune_env.OnlineTuneEnv( + trainer_class=functools.partial(MockTrainer, metrics_to_report), + model=functools.partial( + models.MLP, n_hidden_layers=0, n_output_classes=1), + inputs=functools.partial( + trax_inputs.random_inputs, + input_shape=(1, 1), + input_dtype=np.float32, + output_shape=(1, 1), + output_dtype=np.float32), + optimizer=optimizers.Momentum, + control_configs=( + ("learning_rate", 1e-3, (1e-9, 10.0), False), + ("weight_decay_rate", 1e-5, (1e-9, 0.1), False), + ), + include_controls_in_observation=False, + output_dir=output_dir, + action_multipliers=action_multipliers, + observation_metrics=[(HISTORY_MODE, METRIC)], + reward_metric=(HISTORY_MODE, METRIC), + train_steps=1, + eval_steps=1, + env_steps=(len(metrics_to_report) - 1)) + + def test_communicates_with_trainer(self): + action_multipliers = [0.8, 1.0, 1.25] + metrics_to_report = [0.1, 0.5, 0.8, 0.9] + actions_to_take = [[0, 1], [1, 2], [2, 0]] + expected_observations = np.expand_dims(metrics_to_report, axis=1) + # Metric difference in consecutive timesteps. + expected_rewards = [0.4, 0.3, 0.1] + expected_dones = [False, False, True] + expected_controls = [ + {"learning_rate": 0.0008, "weight_decay_rate": 1e-5}, + {"learning_rate": 0.0008, "weight_decay_rate": 1.25e-5}, + {"learning_rate": 0.001, "weight_decay_rate": 1e-5}, + ] + + env = self._create_env( + output_dir=self.get_temp_dir(), + metrics_to_report=metrics_to_report, + action_multipliers=action_multipliers) + actual_observations = [env.reset()] + actual_rewards = [] + actual_dones = [] + for action in actions_to_take: + (observation, reward, done, _) = env.step(action) + actual_observations.append(observation) + actual_rewards.append(reward) + actual_dones.append(done) + + np.testing.assert_allclose(actual_observations, expected_observations) + np.testing.assert_allclose(actual_rewards, expected_rewards) + self.assertEqual(actual_dones, expected_dones) + def get_control(name, controls): + return [control[name] for control in controls] + for name in ("learning_rate", "weight_decay_rate"): + np.testing.assert_allclose( + get_control(name, env.trainer.controls), + get_control(name, expected_controls), + ) + + def test_creates_new_trajectory_dirs(self): + output_dir = self.get_temp_dir() + env = self._create_env(output_dir=output_dir) + self.assertEqual(set(gfile.listdir(output_dir)), set()) + env.reset() + self.assertEqual(set(gfile.listdir(output_dir)), {"0"}) + env.reset() + self.assertEqual(set(gfile.listdir(output_dir)), {"0", "1"}) + + +if __name__ == "__main__": + test.main() diff --git a/trax/rl/envs/online_tune_test.py b/trax/rl/envs/online_tune_test.py new file mode 100644 index 000000000..652b2444a --- /dev/null +++ b/trax/rl/envs/online_tune_test.py @@ -0,0 +1,110 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.rl.online_tune.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow import test +from trax import history as trax_history +from trax.rl.envs import online_tune + + +class OnlineTuneTest(test.TestCase): + + def _append_metrics(self, h, metric, values): + for (i, value) in enumerate(values): + h.append(*metric, step=i, value=value) + + def test_retrieves_historical_metric_values(self): + history = trax_history.History() + self._append_metrics(history, ("train", "accuracy"), [0.1, 0.73]) + metric_values = online_tune.historical_metric_values( + history, metric=("train", "accuracy"), observation_range=(0, 5)) + np.testing.assert_array_equal(metric_values, [0.1, 0.73]) + + def test_clips_historical_metric_values(self): + history = trax_history.History() + self._append_metrics(history, ("train", "loss"), [-10, 10]) + metric_values = online_tune.historical_metric_values( + history, metric=("train", "loss"), observation_range=(-1, 1)) + np.testing.assert_array_equal(metric_values, [-1, 1]) + + def test_converts_history_to_observations_without_learning_rate(self): + history = trax_history.History() + self._append_metrics(history, ("train", "loss"), [3.0, 1.07]) + self._append_metrics(history, ("eval", "accuracy"), [0.12, 0.68]) + observations = online_tune.history_to_observations( + history, + metrics=(("eval", "accuracy"), ("train", "loss")), + observation_range=(0, 5), + include_lr=False, + ) + np.testing.assert_array_equal(observations, [[0.12, 3.0], [0.68, 1.07]]) + + def test_converts_history_to_observations_with_learning_rate(self): + history = trax_history.History() + self._append_metrics( + history, ("train", "training/learning_rate"), [1e-3, 1e-4]) + observations = online_tune.history_to_observations( + history, + metrics=(), + observation_range=(0, 5), + include_lr=True, + ) + self.assertEqual(observations.shape, (2, 1)) + ((log_lr_1,), (log_lr_2,)) = observations + self.assertGreater(log_lr_1, log_lr_2) + + def test_clips_observations(self): + history = trax_history.History() + self._append_metrics(history, ("eval", "loss"), [-10, 10]) + observations = online_tune.history_to_observations( + history, + metrics=(("eval", "loss"),), + observation_range=(-2, 2), + include_lr=False, + ) + np.testing.assert_array_equal(observations, [[-2], [2]]) + + def test_calculates_new_learning_rate(self): + history = trax_history.History() + self._append_metrics( + history, online_tune.LEARNING_RATE_METRIC, [1e-2, 1e-3]) + new_lr = online_tune.new_learning_rate( + action=2, + history=history, + action_multipliers=(0.5, 1.0, 2.0), + max_lr=1.0, + ) + np.testing.assert_almost_equal(new_lr, 2e-3) + + def test_clips_new_learning_rate(self): + history = trax_history.History() + self._append_metrics(history, online_tune.LEARNING_RATE_METRIC, [1e-3]) + new_lr = online_tune.new_learning_rate( + action=0, + history=history, + action_multipliers=(4.0, 1.0, 0.25), + max_lr=3e-3, + ) + np.testing.assert_almost_equal(new_lr, 3e-3) + + +if __name__ == "__main__": + test.main() diff --git a/trax/rl/online_tune.py b/trax/rl/online_tune.py new file mode 100644 index 000000000..47465e066 --- /dev/null +++ b/trax/rl/online_tune.py @@ -0,0 +1,115 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for OnlineTuneEnv.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + + +def historical_metric_values(history, metric): + """Converts a metric stream from a trax History object into a numpy array.""" + metric_sequence = history.get(*metric) + metric_values = np.array([ + metric_value for (_, metric_value) in metric_sequence + ]) + if np.any(np.isnan(metric_values)): + # Zero out all observations if any element is NaN. This way the agent + # doesn't get any rewards, so it learns to avoid those regions. + metric_values[:] = 0.0 + return metric_values + + +def control_to_observation(control_values, control_config, observation_range): + """Flips, logarithms, clips and scales the control to observation_range.""" + (_, _, (low, high), flip) = control_config + def transform(x): + return np.log(maybe_flip(x, flip)) + (log_control_values, log_low, log_high) = map( + transform, (control_values, low, high) + ) + if flip: + (log_low, log_high) = (log_high, log_low) + log_control_values = np.clip(log_control_values, log_low, log_high) + # Rescale the log control values to the observation range. + (obs_low, obs_high) = observation_range + return ( + (log_control_values - log_low) / (log_high - log_low) * + (obs_high - obs_low) + obs_low + ) + + +def control_metric(name): + """Returns the (mode, metric) pair in History for the given control.""" + return ("train", "training/{}".format(name)) + + +def maybe_flip(value, flip): + """Flips a control (or not). + + Meant to translate controls that naturally take values close to 1 + (e.g. momentum) to a space where multiplication makes sense (i.e. close to 0). + + Args: + value: float or numpy array, value of the control. + flip: bool, whether to flip or not. + + Returns: + Either value or 1 - value based on flip. + """ + if flip: + value = 1 - value + return value + + +def history_to_observations( + history, metrics, observation_range, control_configs=None): + """Converts a trax History object into a sequence of observations.""" + (obs_low, obs_high) = observation_range + observation_dimensions = [ + np.clip(historical_metric_values(history, metric), obs_low, obs_high) + for metric in metrics + ] + if control_configs is not None: + for control_config in control_configs: + (control_name, _, _, _) = control_config + observation_dimensions.append(control_to_observation( + historical_metric_values(history, control_metric(control_name)), + control_config, + observation_range, + )) + return np.stack(observation_dimensions, axis=1) + + +def update_control(control_config, action, history, action_multipliers): + """Calculates a new value of a control based on an action.""" + (name, _, (low, high), flip) = control_config + metric = control_metric(name) + control_values = historical_metric_values(history, metric) + assert control_values.shape[0] > 0, ( + "No last control {} found in history.".format(name)) + current_control = control_values[-1] + (current_control, low, high) = maybe_flip( + np.array([current_control, low, high]), flip + ) + if flip: + (low, high) = (high, low) + new_control = np.clip( + current_control * action_multipliers[action], low, high + ) + return maybe_flip(new_control, flip) diff --git a/trax/rl/online_tune_test.py b/trax/rl/online_tune_test.py new file mode 100644 index 000000000..1786e99c1 --- /dev/null +++ b/trax/rl/online_tune_test.py @@ -0,0 +1,175 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.online_tune.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow import test +from trax import history as trax_history +from trax.rl import online_tune + + +class OnlineTuneTest(test.TestCase): + + def _append_metrics(self, h, metric, values): + for (i, value) in enumerate(values): + h.append(*metric, step=i, value=value) + + def test_retrieves_historical_metric_values(self): + history = trax_history.History() + self._append_metrics(history, ("train", "accuracy"), [0.1, 0.73]) + metric_values = online_tune.historical_metric_values( + history, metric=("train", "accuracy") + ) + np.testing.assert_array_equal(metric_values, [0.1, 0.73]) + + def test_converts_control_to_log_scale_without_flipping(self): + config = ("weight_decay", None, (1e-5, 0.1), False) + controls = np.array([0.01, 0.02, 0.04]) + obs_range = (-1, 1) + obs = online_tune.control_to_observation(controls, config, obs_range) + np.testing.assert_almost_equal(obs[1] - obs[0], obs[2] - obs[1]) + + def test_converts_control_to_log_scale_with_flipping(self): + config = ("momentum", None, (0.5, 0.99), True) + controls = np.array([0.98, 0.96, 0.92]) + obs_range = (-1, 1) + obs = online_tune.control_to_observation(controls, config, obs_range) + np.testing.assert_almost_equal(obs[1] - obs[0], obs[2] - obs[1]) + + def test_clips_control_without_flipping(self): + config = ("weight_decay", None, (1e-5, 0.1), False) + controls = np.array([0.0, 0.2]) + obs_range = (-1, 1) + obs = online_tune.control_to_observation(controls, config, obs_range) + np.testing.assert_equal(obs, [-1, 1]) + + def test_clips_control_with_flipping(self): + config = ("momentum", None, (0.5, 0.99), True) + controls = np.array([0.4, 1.0]) + obs_range = (-1, 1) + obs = online_tune.control_to_observation(controls, config, obs_range) + np.testing.assert_equal(obs, [1, -1]) + + def test_rescales_control(self): + config = ("weight_decay", None, (1e-5, 0.1), False) + controls = np.array([4e-4, 3e-3, 2e-2]) + (obs_low, obs_high) = (103, 104) + obs = online_tune.control_to_observation( + controls, config, observation_range=(obs_low, obs_high), + ) + np.testing.assert_array_less(obs, [obs_high] * 3) + np.testing.assert_array_less([obs_low] * 3, obs) + + def test_converts_history_to_observations_without_controls(self): + history = trax_history.History() + self._append_metrics(history, ("train", "loss"), [1.0, 0.07]) + self._append_metrics(history, ("eval", "accuracy"), [0.12, 0.68]) + observations = online_tune.history_to_observations( + history, + metrics=(("eval", "accuracy"), ("train", "loss")), + observation_range=(-1, 1), + control_configs=None, + ) + np.testing.assert_array_almost_equal( + observations, [[0.12, 1.0], [0.68, 0.07]] + ) + + def test_converts_history_to_observations_with_controls(self): + history = trax_history.History() + self._append_metrics( + history, ("train", "training/learning_rate"), [1e-3, 1e-4]) + observations = online_tune.history_to_observations( + history, + metrics=(), + observation_range=(0, 5), + control_configs=( + ("learning_rate", None, (1e-9, 10.0), False), + ), + ) + self.assertEqual(observations.shape, (2, 1)) + ((log_lr_1,), (log_lr_2,)) = observations + self.assertGreater(log_lr_1, log_lr_2) + + def test_clips_observations(self): + history = trax_history.History() + self._append_metrics(history, ("eval", "loss"), [-10, 10]) + observations = online_tune.history_to_observations( + history, + metrics=(("eval", "loss"),), + observation_range=(-2, 2), + control_configs=None, + ) + np.testing.assert_array_equal(observations, [[-2], [2]]) + + def test_updates_control_without_flipping(self): + config = ("learning_rate", None, (1e-9, 10.0), False) + history = trax_history.History() + self._append_metrics( + history, online_tune.control_metric("learning_rate"), [1e-2, 1e-3]) + new_control = online_tune.update_control( + control_config=config, + action=2, + history=history, + action_multipliers=(0.5, 1.0, 2.0), + ) + np.testing.assert_almost_equal(new_control, 2e-3) + + def test_updates_control_with_flipping(self): + config = ("momentum", None, (0.5, 0.99), True) + history = trax_history.History() + self._append_metrics( + history, online_tune.control_metric("momentum"), [0.96, 0.98]) + new_control = online_tune.update_control( + control_config=config, + action=0, + history=history, + action_multipliers=(0.5, 1.0, 2.0), + ) + np.testing.assert_almost_equal(new_control, 0.99) + + def test_clips_updated_control_without_flipping(self): + config = ("learning_rate", None, (1e-9, 10.0), False) + history = trax_history.History() + self._append_metrics( + history, online_tune.control_metric("learning_rate"), [7.0]) + new_control = online_tune.update_control( + control_config=config, + action=2, + history=history, + action_multipliers=(0.5, 1.0, 2.0), + ) + np.testing.assert_almost_equal(new_control, 10.0) + + def test_clips_updated_control_with_flipping(self): + config = ("momentum", None, (0.5, 0.99), True) + history = trax_history.History() + self._append_metrics( + history, online_tune.control_metric("momentum"), [0.985]) + new_control = online_tune.update_control( + control_config=config, + action=0, + history=history, + action_multipliers=(0.5, 1.0, 2.0), + ) + np.testing.assert_almost_equal(new_control, 0.99) + + +if __name__ == "__main__": + test.main() diff --git a/trax/rl/ppo.py b/trax/rl/ppo.py new file mode 100644 index 000000000..32cb6fc95 --- /dev/null +++ b/trax/rl/ppo.py @@ -0,0 +1,971 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PPO in JAX. + +Notation: + +B, scalar - batch size +RT, scalar - (reward time) number of time-steps in a trajectory, or the size + of the padded reward sequence. +AT, scalar - (action time) number of controls in a trajectory, or the size + of the policy network output. +OBS, tuple - shape of a singular observation from the environment. + Ex: For CartPole-v0 this is (4,) and Pong-v0 it's (210, 160, 3) +A, scalar - Number of actions, assuming a discrete space. + +Policy and Value function signatures: + +Policy Function :: [B, RT + 1] + OBS -> [B, AT, A] +Value Function :: [B, RT + 1] + OBS -> [B, AT] +Policy and Value Function :: [B, RT + 1] + OBS -> ([B, AT, A], [B, AT]) + +i.e. the policy net should take a batch of *trajectories* and at each time-step +in each batch deliver a probability distribution over actions. + +NOTE: It doesn't return logits, rather the expectation is that it returns +log-probabilities instead. + +NOTE: The policy and value functions need to take care to not take into account +future time-steps while deciding the actions (or value) for the current +time-step. + +Policy and Value Function produces a tuple of the expected output of a policy +function and a value function. + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import functools +import itertools +import os +import re +import time + +from absl import logging +from jax import grad +from jax import jit +from jax import lax +from jax import numpy as np +import numpy as onp + +from tensor2tensor.envs import env_problem +from tensor2tensor.envs import env_problem_utils +from tensorflow.io import gfile +from trax import layers as tl +from trax import utils + + +def policy_and_value_net( + n_actions, n_controls, vocab_size, bottom_layers_fn, two_towers +): + """A policy and value net function.""" + + # Layers. + + # Now, with the current logits, one head computes action probabilities and the + # other computes the value function. + # NOTE: The LogSoftmax instead of the Softmax because of numerical stability. + + @tl.layer() + def FlattenControlsIntoTime(x, **unused_kwargs): # pylint: disable=invalid-name + """Splits logits for actions in different controls and flattens controls.""" + return np.reshape(x, (x.shape[0], -1, n_actions)) + + if vocab_size is None: + # In continuous policies every element of the output sequence corresponds to + # an observation. + n_preds_per_input = n_controls + kwargs = {} + else: + # In discrete policies every element of the output sequence corresponds to + # a symbol in the discrete representation, and each control takes 1 symbol. + n_preds_per_input = 1 + kwargs = {"vocab_size": vocab_size} + + if two_towers: + layers = [ + tl.Dup(), + tl.Parallel( + [bottom_layers_fn(**kwargs), + tl.Dense(n_preds_per_input * n_actions), + FlattenControlsIntoTime(), # pylint: disable=no-value-for-parameter + tl.LogSoftmax()], + [bottom_layers_fn(**kwargs), + tl.Dense(n_preds_per_input), + tl.Flatten()], + ) + ] + else: + layers = [ + bottom_layers_fn(**kwargs), + tl.Dup(), + tl.Parallel( + [tl.Dense(n_preds_per_input * n_actions), + FlattenControlsIntoTime(), # pylint: disable=no-value-for-parameter + tl.LogSoftmax()], + [tl.Dense(n_preds_per_input), tl.Flatten()], + ) + ] + return tl.Model(layers) + + +def optimizer_fn(optimizer, net_params): + """Exposes a convenient interface for the optimizer. + + Args: + optimizer: Optimizer class to use. + net_params: A nested structure of network parameters. + + Returns: + A tuple (opt_state, opt_update, get_params), where: + opt_state: Pair (net_params, opt_slots) - initial optimization state. + opt_update: Function (step, grads, opt_state) -> opt_state doing one + optimization step. + get_params: Function opt_state -> net_params for extracting the network + parameters from the optimization state. + """ + opt = optimizer() + (init_slots, init_nontrainable_slots) = opt.tree_init(net_params) + init_state = (net_params, init_slots) + + def opt_update(step, grads, opt_state): + (params, slots) = opt_state + # Pass the initial nontrainable_slots as we don't tune them during training. + # (yet!) + return opt.tree_update(step, grads, params, slots, init_nontrainable_slots) + + def get_params(opt_state): + (params, _) = opt_state + return params + + return init_state, opt_update, get_params + + +# Should this be collect 'n' trajectories, or +# Run the env for 'n' steps and take completed trajectories, or +# Any other option? +def collect_trajectories(env, + policy_fn, + n_trajectories=1, + max_timestep=None, + reset=True, + len_history_for_policy=32, + boundary=32, + state=None, + temperature=1.0, + rng=None, + abort_fn=None, + raw_trajectory=False,): + """Collect trajectories with the given policy net and behaviour. + + Args: + env: A gym env interface, for now this is not-batched. + policy_fn: observations(B,RT+1) -> log-probabs(B, AT, A) callable. + n_trajectories: int, number of trajectories. + max_timestep: int or None, the index of the maximum time-step at which we + return the trajectory, None for ending a trajectory only when env returns + done. + reset: bool, true if we want to reset the envs. The envs are also reset if + max_max_timestep is None or < 0 + len_history_for_policy: int or None, the maximum history to keep for + applying the policy on. If None, use the full history. + boundary: int, pad the sequences to the multiples of this number. + state: state for `policy_fn`. + temperature: (float) temperature to sample action from policy_fn. + rng: jax rng, splittable. + abort_fn: callable, If not None, then at every env step call and abort the + trajectory collection if it returns True, if so reset the env and return + None. + raw_trajectory: bool, if True a list of trajectory.Trajectory objects is + returned, otherwise a list of numpy representations of + `trajectory.Trajectory` is returned. + + Returns: + A tuple (trajectory, number of trajectories that are done) + trajectory: list of (observation, action, reward) tuples, where each element + `i` is a tuple of numpy arrays with shapes as follows: + observation[i] = (B, T_i + 1) + action[i] = (B, T_i) + reward[i] = (B, T_i) + """ + + assert isinstance(env, env_problem.EnvProblem) + # This is an env_problem, run its collect function. + trajs, n_done, timing_info, state = env_problem_utils.play_env_problem_with_policy( + env, + policy_fn, + num_trajectories=n_trajectories, + max_timestep=max_timestep, + reset=reset, + len_history_for_policy=len_history_for_policy, + boundary=boundary, + state=state, + temperature=temperature, + rng=rng, + abort_fn=abort_fn, + raw_trajectory=raw_trajectory, + ) + # Skip returning raw_rewards here, since they aren't used. + + # t is the return value of Trajectory.as_numpy, so: + # (observation, action, processed_reward, raw_reward, infos) + return trajs, n_done, timing_info, state + + +# This function can probably be simplified, ask how? +# Can we do something much simpler than lax.pad, maybe np.pad? +# Others? + + +def get_padding_value(dtype): + """Returns the padding value given a dtype.""" + padding_value = None + if dtype == np.uint8: + padding_value = np.uint8(0) + elif dtype == np.uint16: + padding_value = np.uint16(0) + elif dtype == np.float32 or dtype == np.float64: + padding_value = 0.0 + else: + padding_value = 0 + assert padding_value is not None + return padding_value + + +# TODO(afrozm): Use np.pad instead and make jittable? +def pad_trajectories(trajectories, boundary=20): + """Pad trajectories to a bucket length that is a multiple of boundary. + + Args: + trajectories: list[(observation, actions, rewards)], where each observation + is shaped (t+1,) + OBS and actions & rewards are shaped (t,), with the + length of the list being B (batch size). + boundary: int, bucket length, the actions and rewards are padded to integer + multiples of boundary. + + Returns: + tuple: (padding lengths, reward_mask, padded_observations, padded_actions, + padded_rewards) where padded_observations is shaped (B, RT+1) + OBS and + padded_actions, padded_rewards & reward_mask are shaped (B, RT). + Where RT is max(t) rounded up to an integer multiple of boundary. + padded_length is how much padding we've added and + reward_mask is 1s for actual rewards and 0s for the padding. + """ + + # Let's compute max(t) over all trajectories. + t_max = max(r.shape[0] for (_, _, r, _) in trajectories) + + # t_max is rounded to the next multiple of `boundary` + boundary = int(boundary) + bucket_length = boundary * int(np.ceil(float(t_max) / boundary)) + + # So all obs will be padded to t_max + 1 and actions and rewards to t_max. + padded_observations = [] + padded_actions = [] + padded_rewards = [] + padded_infos = collections.defaultdict(list) + padded_lengths = [] + reward_masks = [] + + for (o, a, r, i) in trajectories: + # Determine the amount to pad, this holds true for obs, actions and rewards. + num_to_pad = bucket_length + 1 - o.shape[0] + padded_lengths.append(num_to_pad) + if num_to_pad == 0: + padded_observations.append(o) + padded_actions.append(a) + padded_rewards.append(r) + reward_masks.append(onp.ones_like(r, dtype=np.int32)) + if i: + for k, v in i.items(): + padded_infos[k].append(v) + continue + + # First pad observations. + padding_config = tuple([(0, num_to_pad, 0)] + [(0, 0, 0)] * (o.ndim - 1)) + + padding_value = get_padding_value(o.dtype) + action_padding_value = get_padding_value(a.dtype) + reward_padding_value = get_padding_value(r.dtype) + + padded_obs = lax.pad(o, padding_value, padding_config) + padded_observations.append(padded_obs) + + # Now pad actions and rewards. + padding_config = tuple([(0, num_to_pad, 0)] + [(0, 0, 0)] * (a.ndim - 1)) + padded_action = lax.pad(a, action_padding_value, padding_config) + padded_actions.append(padded_action) + + assert r.ndim == 1 + padding_config = ((0, num_to_pad, 0),) + padded_reward = lax.pad(r, reward_padding_value, padding_config) + padded_rewards.append(padded_reward) + + # Also create the mask to use later. + reward_mask = onp.ones_like(r, dtype=np.int32) + reward_masks.append(lax.pad(reward_mask, 0, padding_config)) + + if i: + for k, v in i.items(): + # Create a padding configuration for this value. + padding_config = [(0, num_to_pad, 0)] + [(0, 0, 0)] * (v.ndim - 1) + padded_infos[k].append(lax.pad(v, 0.0, tuple(padding_config))) + + # Now stack these padded_infos if they exist. + stacked_padded_infos = None + if padded_infos: + stacked_padded_infos = {k: np.stack(v) for k, v in padded_infos.items()} + + return padded_lengths, np.stack(reward_masks), np.stack( + padded_observations), np.stack(padded_actions), np.stack( + padded_rewards), stacked_padded_infos + + +def rewards_to_go(rewards, mask, gamma=0.99): + r"""Computes rewards to go. + + Reward to go is defined as follows, the discounted reward that we have to + yet collect, going forward from this point, i.e.: + + r2g_t = \sum_{l=0}^{\infty} (\gamma^{l} * reward_{t+l}) + + Args: + rewards: np.ndarray of shape (B, RT) of rewards. + mask: np.ndarray of shape (B, RT) of mask for the rewards. + gamma: float, discount factor. + + Returns: + rewards to go, np.ndarray of shape (B, RT). + """ + B, RT = rewards.shape # pylint: disable=invalid-name,unused-variable + + masked_rewards = rewards * mask # (B, RT) + + # The lax.scan version of this is slow, but we still show it here for + # completeness. + # rewards_rev = np.flip(masked_rewards, axis=1) # (B, T) flipped on time. + # rrt = np.transpose(rewards_rev) # (T, B) transpose to scan over time. + # + # def discounting_add(carry, reward): + # x = reward + (gamma * carry) + # return x, x + # + # _, ys = lax.scan(discounting_add, + # np.zeros_like(rrt[0], dtype=np.float32), + # rrt.astype(np.float32)) + # + # # ys is (T, B) and T is in reverse order. + # return np.flip(np.transpose(ys), axis=1) + + # We use the following recurrence relation, derived from the equation above: + # + # r2g[t+1] = (r2g[t] - r[t]) / gamma + # + # This means we'll need to calculate r2g[0] first and then r2g[1] and so on .. + # + # **However** this leads to overflows for long sequences: r2g[t] - r[t] > 0 + # and gamma < 1.0, so the division keeps increasing. + # + # So we just run the recurrence in reverse, i.e. + # + # r2g[t] = r[t] + (gamma*r2g[t+1]) + # + # This is much better, but might have lost updates since the (small) rewards + # at earlier time-steps may get added to a (very?) large sum. + + # Compute r2g_{T-1} at the start and then compute backwards in time. + r2gs = [masked_rewards[:, -1]] + + # Go from T-2 down to 0. + for t in reversed(range(RT - 1)): + r2gs.append(masked_rewards[:, t] + (gamma * r2gs[-1])) + + # The list should have length RT. + assert RT == len(r2gs) + + # First we stack them in the correct way to make it (B, RT), but these are + # still from newest (RT-1) to oldest (0), so then we flip it on time axis. + return np.flip(np.stack(r2gs, axis=1), axis=1) + + +@jit +def value_loss_given_predictions(value_prediction, + rewards, + reward_mask, + gamma=0.99, + epsilon=0.2, + value_prediction_old=None): + """Computes the value loss given the prediction of the value function. + + Args: + value_prediction: np.ndarray of shape (B, RT+1, 1) + rewards: np.ndarray of shape (B, RT) of rewards. + reward_mask: np.ndarray of shape (B, RT), the mask over rewards. + gamma: float, discount factor. + epsilon: float, clip-fraction, used if value_value_prediction_old isn't None + value_prediction_old: np.ndarray of shape (B, RT+1, 1) of value predictions + using the old parameters. If provided, we incorporate this in the loss as + well. This is from the OpenAI baselines implementation. + + Returns: + Pair (value_loss, summaries), where value_loss is the average L2 value loss, + averaged over instances where reward_mask is 1. Summaries is a dict of + summaries collected during value loss computation. + """ + + B, RT = rewards.shape # pylint: disable=invalid-name + assert (B, RT) == reward_mask.shape + assert (B, RT + 1) == value_prediction.shape + + value_prediction = value_prediction[:, :-1] * reward_mask # (B, RT) + r2g = rewards_to_go(rewards, reward_mask, gamma=gamma) # (B, RT) + loss = (value_prediction - r2g)**2 + + # From the baselines implementation. + if value_prediction_old is not None: + value_prediction_old = value_prediction_old[:, :-1] * reward_mask # (B, RT) + + v_clipped = value_prediction_old + np.clip( + value_prediction - value_prediction_old, -epsilon, epsilon) + v_clipped_loss = (v_clipped - r2g)**2 + loss = np.maximum(v_clipped_loss, loss) + + # Take an average on only the points where mask != 0. + value_loss = np.sum(loss) / np.sum(reward_mask) + + summaries = { + "value_loss": value_loss, + } + + return (value_loss, summaries) + + +def deltas(predicted_values, rewards, mask, gamma=0.99): + r"""Computes TD-residuals from V(s) and rewards. + + Where a `delta`, i.e. a td-residual is defined as: + + delta_{b,t} = r_{b,t} + \gamma * v_{b,t+1} - v_{b,t}. + + Args: + predicted_values: ndarray of shape (B, RT+1). NOTE: Expects axis 2 was + squeezed. These represent V(s_bt) for b < B and t < RT+1 + rewards: ndarray of shape (B, RT) of rewards. + mask: ndarray of shape (B, RT) of mask for rewards. + gamma: float, discount factor. + + Returns: + ndarray of shape (B, RT) of one-step TD-residuals. + """ + + # Predicted values at time t, cutting off the last to have shape (B, RT). + predicted_values_bt = predicted_values[:, :-1] + # Predicted values at time t+1, by cutting off the first to have shape (B, RT) + predicted_values_btplus1 = predicted_values[:, 1:] + # Return the deltas as defined above. + return (rewards + + (gamma * predicted_values_btplus1) - predicted_values_bt) * mask + + +def gae_advantages(td_deltas, mask, lambda_=0.95, gamma=0.99): + r"""Computes the GAE advantages given the one step TD-residuals. + + The formula for a GAE advantage estimator is as follows: + + A_{bt} = \sum_{l=0}^{\infty}(\gamma * \lambda)^{l}(\delta_{b,t+l}). + + Internally we just call rewards_to_go, since it is the same computation. + + Args: + td_deltas: np.ndarray of shape (B, RT) of one step TD-residuals. + mask: np.ndarray of shape (B, T) of mask for the residuals. It maybe the + case that the `td_deltas` are already masked correctly since they are + produced by `deltas(...)` + lambda_: float, lambda parameter for GAE estimators. + gamma: float, lambda parameter for GAE estimators. + + Returns: + GAE advantage estimates. + """ + + return rewards_to_go(td_deltas, mask, lambda_ * gamma) + + +def chosen_probabs(probab_actions, actions): + """Picks out the probabilities of the actions along batch and time-steps. + + Args: + probab_actions: ndarray of shape `[B, AT, A]`, where + probab_actions[b, t, i] contains the log-probability of action = i at + the t^th time-step in the b^th trajectory. + actions: ndarray of shape `[B, AT]`, with each entry in [0, A) denoting + which action was chosen in the b^th trajectory's t^th time-step. + + Returns: + `[B, AT, A]` ndarray with the log-probabilities of the chosen actions. + """ + B, AT = actions.shape # pylint: disable=invalid-name + assert (B, AT) == probab_actions.shape[:2] + return probab_actions[np.arange(B)[:, None], np.arange(AT), actions] + + +def compute_probab_ratios(p_new, p_old, actions, action_mask): + """Computes the probability ratios for each time-step in a trajectory. + + Args: + p_new: ndarray of shape [B, AT, A] of the log-probabilities that the + policy network assigns to all the actions at each time-step in each batch + using the old parameters. + p_old: ndarray of shape [B, AT, A], same as above, but using old policy + network parameters. + actions: ndarray of shape [B, AT] where each element is from [0, A). + action_mask: ndarray of shape [B, T] masking over probabilities. + + Returns: + probab_ratios: ndarray of shape [B, AT], where + probab_ratios_{b,t,} = p_new_{b,t,action_{b,t}} / + p_old_{b,t,action_{b,t}} + """ + + B, AT = actions.shape # pylint: disable=invalid-name + assert (B, AT) == p_old.shape[:2] + assert (B, AT) == p_new.shape[:2] + + logp_old = chosen_probabs(p_old, actions) + logp_new = chosen_probabs(p_new, actions) + + assert (B, AT) == logp_old.shape + assert (B, AT) == logp_new.shape + + # Since these are log-probabilities, we just subtract them. + probab_ratios = np.exp(logp_new - logp_old) * action_mask + assert (B, AT) == probab_ratios.shape + return probab_ratios + + +def clipped_probab_ratios(probab_ratios, epsilon=0.2): + return np.clip(probab_ratios, 1 - epsilon, 1 + epsilon) + + +def clipped_objective(probab_ratios, advantages, action_mask, epsilon=0.2): + advantages = advantages + return np.minimum( + probab_ratios * advantages, + clipped_probab_ratios(probab_ratios, epsilon=epsilon) * + advantages) * action_mask + + +@jit +def ppo_loss_given_predictions(log_probab_actions_new, + log_probab_actions_old, + value_predictions_old, + padded_actions, + rewards_to_actions, + padded_rewards, + reward_mask, + gamma=0.99, + lambda_=0.95, + epsilon=0.2): + """PPO objective, with an eventual minus sign, given predictions.""" + B, RT = padded_rewards.shape # pylint: disable=invalid-name + _, AT, A = log_probab_actions_old.shape # pylint: disable=invalid-name + + assert (B, RT) == padded_rewards.shape + assert (B, AT) == padded_actions.shape + assert (B, RT) == reward_mask.shape + + assert (B, RT + 1) == value_predictions_old.shape + assert (B, AT, A) == log_probab_actions_old.shape + assert (B, AT, A) == log_probab_actions_new.shape + + assert (RT + 1, AT) == rewards_to_actions.shape + + # (B, RT) + td_deltas = deltas( + value_predictions_old, # (B, RT+1) + padded_rewards, + reward_mask, + gamma=gamma) + + # (B, RT) + advantages = gae_advantages( + td_deltas, reward_mask, lambda_=lambda_, gamma=gamma) + + # Normalize the advantages. + advantage_mean = np.mean(advantages) + advantage_std = np.std(advantages) + advantages = (advantages - advantage_mean) / (advantage_std + 1e-8) + + # Scatter advantages over padded_actions. + # rewards_to_actions is RT + 1 -> AT, so we pad the advantages and the reward + # mask by 1. + advantages = np.dot(np.pad(advantages, ((0, 0), (0, 1))), rewards_to_actions) + action_mask = np.dot( + np.pad(reward_mask, ((0, 0), (0, 1))), rewards_to_actions + ) + + # (B, AT) + ratios = compute_probab_ratios(log_probab_actions_new, log_probab_actions_old, + padded_actions, action_mask) + assert (B, AT) == ratios.shape + + # (B, AT) + objective = clipped_objective( + ratios, advantages, action_mask, epsilon=epsilon) + assert (B, AT) == objective.shape + + # () + average_objective = np.sum(objective) / np.sum(action_mask) + + # Loss is negative objective. + ppo_loss = -average_objective + + summaries = { + "ppo_loss": ppo_loss, + "advantage_mean": advantage_mean, + "advantage_std": advantage_std, + } + + return (ppo_loss, summaries) + + +@jit +def combined_loss_given_predictions(log_probab_actions_new, + log_probab_actions_old, + value_prediction_new, + value_prediction_old, + padded_actions, + rewards_to_actions, + padded_rewards, + reward_mask, + gamma=0.99, + lambda_=0.95, + epsilon=0.2, + c1=1.0, + c2=0.01): + """Computes the combined (clipped loss + value loss) given predictions.""" + # Sum values over symbols in an action's representation, because it's a simple + # way of going from AT to RT+1 and does not decrease the expressive power. + value_prediction_old = np.dot( + value_prediction_old, rewards_to_actions.transpose() + ) + value_prediction_new = np.dot( + value_prediction_new, rewards_to_actions.transpose() + ) + (value_loss, value_summaries) = value_loss_given_predictions( + value_prediction_new, + padded_rewards, + reward_mask, + gamma=gamma, + value_prediction_old=value_prediction_old, + epsilon=epsilon) + (ppo_loss, ppo_summaries) = ppo_loss_given_predictions( + log_probab_actions_new, + log_probab_actions_old, + value_prediction_old, + padded_actions, + rewards_to_actions, + padded_rewards, + reward_mask, + gamma=gamma, + lambda_=lambda_, + epsilon=epsilon) + # Pad the reward mask to be compatible with rewards_to_actions. + padded_reward_mask = np.pad(reward_mask, ((0, 0), (0, 1))) + action_mask = np.dot(padded_reward_mask, rewards_to_actions) + entropy_bonus = masked_entropy(log_probab_actions_new, action_mask) + combined_loss_ = ppo_loss + (c1 * value_loss) - (c2 * entropy_bonus) + + summaries = { + "combined_loss": combined_loss_, + "entropy_bonus": entropy_bonus, + } + for loss_summaries in (value_summaries, ppo_summaries): + summaries.update(loss_summaries) + + return (combined_loss_, (ppo_loss, value_loss, entropy_bonus), summaries) + + +@functools.partial(jit, static_argnums=(3,)) +def combined_loss(new_params, + log_probab_actions_old, + value_predictions_old, + policy_and_value_net_apply, + padded_observations, + padded_actions, + rewards_to_actions, + padded_rewards, + reward_mask, + gamma=0.99, + lambda_=0.95, + epsilon=0.2, + c1=1.0, + c2=0.01, + state=None, + rng=None): + """Computes the combined (clipped loss + value loss) given observations.""" + (log_probab_actions_new, value_predictions_new) = ( + policy_and_value_net_apply( + padded_observations, params=new_params, state=state, rng=rng)) + + (loss, component_losses, summaries) = combined_loss_given_predictions( + log_probab_actions_new, + log_probab_actions_old, + value_predictions_new, + value_predictions_old, + padded_actions, + rewards_to_actions, + padded_rewards, + reward_mask, + gamma=gamma, + lambda_=lambda_, + epsilon=epsilon, + c1=c1, + c2=c2, + ) + return (loss, component_losses, summaries, state) + + +@functools.partial(jit, static_argnums=(2, 3, 4)) +def policy_and_value_opt_step(i, + opt_state, + opt_update, + get_params, + policy_and_value_net_apply, + log_probab_actions_old, + value_predictions_old, + padded_observations, + padded_actions, + rewards_to_actions, + padded_rewards, + reward_mask, + c1=1.0, + c2=0.01, + gamma=0.99, + lambda_=0.95, + epsilon=0.1, + state=None, + rng=None): + """Policy and Value optimizer step.""" + + # Combined loss function given the new params. + def policy_and_value_loss(params, state): + """Returns the combined loss given just parameters.""" + (loss, _, _, state) = combined_loss( + params, + log_probab_actions_old, + value_predictions_old, + policy_and_value_net_apply, + padded_observations, + padded_actions, + rewards_to_actions, + padded_rewards, + reward_mask, + c1=c1, + c2=c2, + gamma=gamma, + lambda_=lambda_, + epsilon=epsilon, + state=state, + rng=rng) + return loss, state + + new_params = get_params(opt_state) + g, state = grad(policy_and_value_loss, has_aux=True)(new_params, state) + # TODO(afrozm): Maybe clip gradients? + return opt_update(i, g, opt_state), state + + +def get_time(t1, t2=None): + if t2 is None: + t2 = time.time() + return round((t2 - t1) * 1000, 2) + + +def approximate_kl(log_prob_new, log_prob_old, mask): + """Computes the approximate KL divergence between the old and new log-probs. + + Args: + log_prob_new: (B, AT, A) log probs new + log_prob_old: (B, AT, A) log probs old + mask: (B, AT) + + Returns: + Approximate KL. + """ + diff = log_prob_old - log_prob_new + # Mask out the irrelevant part. + diff *= mask[:, :, np.newaxis] # make mask (B, RT, 1) + # Average on non-masked part. + return np.sum(diff) / np.sum(mask) + + +def masked_entropy(log_probs, mask): + """Computes the entropy for the given log-probs. + + Args: + log_probs: (B, AT, A) log probs + mask: (B, AT) mask. + + Returns: + Entropy. + """ + # Mask out the irrelevant part. + lp = log_probs * mask[:, :, np.newaxis] # make mask (B, RT, 1) + p = np.exp(lp) * mask[:, :, np.newaxis] # (B, RT, 1) + # Average on non-masked part and take negative. + return -(np.sum(lp * p) / np.sum(mask)) + + +def get_policy_model_files(output_dir): + return list( + reversed( + sorted(gfile.glob(os.path.join(output_dir, "model-??????.pkl"))))) + + +def get_epoch_from_policy_model_file(policy_model_file): + base_name = os.path.basename(policy_model_file) + return int(re.match(r"model-(\d+).pkl", base_name).groups()[0]) + + +def get_policy_model_file_from_epoch(output_dir, epoch): + return os.path.join(output_dir, "model-%06d.pkl" % epoch) + + +def maybe_restore_opt_state(output_dir, + policy_and_value_opt_state=None, + policy_and_value_state=None): + """Maybe restore the optimization state from the checkpoint dir. + + Optimization state includes parameters and optimizer slots. + + Args: + output_dir: Directory where saved model checkpoints are stored. + policy_and_value_opt_state: Default optimization state, returned if model + isn't found. + policy_and_value_state: state of the policy and value network. + + Returns: + tuple (opt_state, state, epoch (int), opt_step (int)) where epoch is the + epoch from which we restored the optimization state, 0 if no checkpoint was + found, and opt_step is the total optimization step (sum of all optimization + steps made up to the current epoch). + """ + pkl_module = utils.get_pickle_module() + epoch = 0 + total_opt_step = 0 + for model_file in get_policy_model_files(output_dir): + logging.info("Trying to restore model from %s", model_file) + try: + with gfile.GFile(model_file, "rb") as f: + policy_and_value_opt_state, policy_and_value_state, total_opt_step = ( + pkl_module.load(f)) + epoch = get_epoch_from_policy_model_file(model_file) + break + except EOFError as e: + logging.error("Unable to load model from: %s with %s", model_file, e) + # Try an older version. + continue + return ( + policy_and_value_opt_state, + policy_and_value_state, + epoch, + total_opt_step, + ) + + +LAST_N_POLICY_MODELS_TO_KEEP = 5 + + +def save_opt_state(output_dir, + policy_and_value_opt_state, + policy_and_value_state, + epoch, + total_opt_step): + """Saves the policy and value network optimization state etc.""" + pkl_module = utils.get_pickle_module() + old_model_files = get_policy_model_files(output_dir) + params_file = os.path.join(output_dir, "model-%06d.pkl" % epoch) + with gfile.GFile(params_file, "wb") as f: + pkl_module.dump( + (policy_and_value_opt_state, policy_and_value_state, total_opt_step), f) + # Keep the last k model files lying around (note k > 1 because the latest + # model file might be in the process of getting read async). + for path in old_model_files[LAST_N_POLICY_MODELS_TO_KEEP:]: + if path != params_file: + gfile.remove(path) + + +def init_policy_from_world_model_checkpoint(policy_params, model_output_dir): + """Initializes policy parameters from world model parameters.""" + pkl_module = utils.get_pickle_module() + params_file = os.path.join(model_output_dir, "model.pkl") + # Don't use trax.restore_state to avoid a circular import. + with gfile.GFile(params_file, "rb") as f: + model_params = pkl_module.load(f)[0][0] + # TODO(pkozakowski): The following, brittle line of code is hardcoded for + # transplanting parameters from TransformerLM to TransformerDecoder-based + # policy network of the same configuration. Figure out a more general method. + policy_params[0] = model_params[0][1:-2] + return policy_params + + +def write_eval_reward_summaries(reward_stats_by_mode, summary_writer, epoch): + """Writes evaluation reward statistics to summary and logs them. + + Args: + reward_stats_by_mode: Nested dict of structure: { + "raw": { + : { + "mean": , + "std": , }, + : ... }, + "processed": ... } + summary_writer: jaxboard.SummaryWriter. + epoch: Current epoch number. + """ + for (reward_mode, reward_stats_by_temp) in reward_stats_by_mode.items(): + for (temperature, reward_stats) in reward_stats_by_temp.items(): + for (stat_name, stat) in reward_stats.items(): + summary_writer.scalar( + "eval/{reward_mode}_reward_{stat_name}/" + "temperature_{temperature}".format( + reward_mode=reward_mode, + stat_name=stat_name, + temperature=temperature), + stat, + step=epoch) + logging.info( + "Epoch [% 6d] Policy Evaluation (%s reward) " + "[temperature %.2f] = %10.2f (+/- %.2f)", epoch, reward_mode, + temperature, reward_stats["mean"], reward_stats["std"]) + + +def shuffled_index_batches(dataset_size, batch_size): + """Generates batches of shuffled indices over a dataset.""" + def shuffled_indices(): + while True: + perm = onp.random.permutation(dataset_size) + for x in perm: + yield x + + indices = shuffled_indices() + while True: + yield onp.array(list(itertools.islice(indices, int(batch_size)))) diff --git a/trax/rl/ppo_test.py b/trax/rl/ppo_test.py new file mode 100644 index 000000000..301f73435 --- /dev/null +++ b/trax/rl/ppo_test.py @@ -0,0 +1,643 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.rl.ppo.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import itertools + +import jax +from jax import random as jax_random +import numpy as np +from tensorflow import test +from tensorflow.io import gfile +from trax import inputs +from trax import layers +from trax import models +from trax import trainer_lib +from trax.rl import ppo + + +class PpoTest(test.TestCase): + + def setUp(self): + super(PpoTest, self).setUp() + self.rng_key = trainer_lib.get_random_number_generator_and_set_seed(0) + + def test_get_policy_model_files(self): + output_dir = self.get_temp_dir() + + def write_policy_model_file(epoch): + with gfile.GFile( + ppo.get_policy_model_file_from_epoch(output_dir, epoch), "w") as f: + f.write("some data") + + epochs = [200, 100, 300] + + # 300, 200, 100 + expected_policy_model_files = [ + output_dir + "/model-000300.pkl", + output_dir + "/model-000200.pkl", + output_dir + "/model-000100.pkl", + ] + + for epoch in epochs: + write_policy_model_file(epoch) + + policy_model_files = ppo.get_policy_model_files(output_dir) + + self.assertEqual(expected_policy_model_files, policy_model_files) + + gfile.rmtree(output_dir) + + def test_get_epoch_from_policy_model_file(self): + self.assertEqual(0, + ppo.get_epoch_from_policy_model_file("model-000000.pkl")) + self.assertEqual(123456, + ppo.get_epoch_from_policy_model_file("model-123456.pkl")) + + def test_get_policy_model_file_from_epoch(self): + self.assertEqual("/tmp/model-000000.pkl", + ppo.get_policy_model_file_from_epoch("/tmp", 0)) + self.assertEqual("/tmp/model-123456.pkl", + ppo.get_policy_model_file_from_epoch("/tmp", 123456)) + + def test_policy_and_value_net(self): + observation_shape = (3, 4, 5) + batch_observation_shape = (1, 1) + observation_shape + n_actions = 2 + n_controls = 3 + pnv_model = ppo.policy_and_value_net( + n_controls=n_controls, + n_actions=n_actions, + vocab_size=None, + bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)], + two_towers=True, + ) + _, _ = pnv_model.initialize_once( + batch_observation_shape, np.float32, self.rng_key) + + batch = 2 + time_steps = 10 + batch_of_observations = np.random.uniform( + size=(batch, time_steps) + observation_shape) + pnv_output = pnv_model(batch_of_observations) + + # Output is a list, first is probab of actions and the next is value output. + self.assertEqual(2, len(pnv_output)) + self.assertEqual( + (batch, time_steps * n_controls, n_actions), pnv_output[0].shape) + self.assertEqual((batch, time_steps * n_controls), pnv_output[1].shape) + + def test_pad_trajectories(self): + observation_shape = (2, 3, 4) + trajectories = [] + n_trajectories = 7 + n_actions = 10 + + # Time-steps are between [min_allowable_time_step, max_allowable_time_step] + max_allowable_time_step = 19 + min_allowable_time_step = 5 + + # The actual max we see in the data. + max_time_step = -1 + + # Bucket length. + bucket_length = 15 + + # Make `n_trajectories` random trajectories. + for i in range(n_trajectories): + time_steps = np.random.randint(min_allowable_time_step, + max_allowable_time_step + 1) + if time_steps > max_time_step: + max_time_step = time_steps + observations = np.random.randint( + 0, 255, size=(time_steps + 1,) + observation_shape).astype(np.uint8) + rewards = np.random.uniform(size=(time_steps,)).astype(np.float32) + actions = np.random.randint( + 0, n_actions, size=(time_steps,)).astype(np.int32) + infos = { + "a": np.random.uniform(size=(time_steps,)).astype(np.float32), + "b": np.random.uniform(size=(time_steps,)).astype(np.float32) + } + trajectories.append((observations, rewards, actions, infos)) + + # Now pad these trajectories. + padded_trajectories = ppo.pad_trajectories( + trajectories, boundary=bucket_length) + + # Expected padding. + i = 1 + while i * bucket_length < max_time_step: + i += 1 + expected_padding = i * bucket_length + + # Get the padded objects. + (pad_lengths, reward_mask, padded_observations, padded_actions, + padded_rewards, padded_infos) = padded_trajectories + + # Expectations on the padded shapes. + self.assertEqual(padded_observations.shape, ( + n_trajectories, + expected_padding + 1, + ) + observation_shape) + self.assertEqual(padded_actions.shape, (n_trajectories, expected_padding)) + self.assertEqual(padded_rewards.shape, (n_trajectories, expected_padding)) + self.assertEqual(reward_mask.shape, (n_trajectories, expected_padding)) + + self.assertEqual(padded_infos["a"].shape, + (n_trajectories, expected_padding)) + self.assertEqual(padded_infos["b"].shape, + (n_trajectories, expected_padding)) + + # Assert that the padding lengths and reward mask are consistent. + self.assertAllEqual( + np.full((n_trajectories,), expected_padding), + np.array(np.sum(reward_mask, axis=1)) + pad_lengths) + + def test_rewards_to_go(self): + rewards = np.array([ + [1, 2, 4, 8, 16, 32, 64, 128], + [1, 1, 1, 1, 1, 1, 1, 1], + ]) + + rewards_mask = np.array([ + [1, 1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 0], + ]) + + gamma = 0.5 + + rewards_to_go = ppo.rewards_to_go(rewards, rewards_mask, gamma) + + self.assertAllEqual( + np.array([ + [5, 8, 12, 16, 16, 0, 0, 0], + [1.984375, 1.96875, 1.9375, 1.875, 1.75, 1.5, 1.0, 0], + ]), rewards_to_go) + + def test_rewards_to_go_really_long_sequences(self): + T = 1200 # pylint: disable=invalid-name + + rewards = np.random.uniform(1e-3, 1e-2, (1, T)) + + # Make a mask, clear out a fixed number `L` of 1s from the end. + L = 36 # pylint: disable=invalid-name + assert L < T + rewards_mask = np.ones_like(rewards) + rewards_mask[0, L:] = 0 + + gamma = 0.94 + + actual_r2g = ppo.rewards_to_go(rewards, rewards_mask, gamma).reshape(-1) + + # Let's compute r2g the slow way. + masked_rewards = (rewards_mask * rewards).reshape(-1) + expected_r2g = np.zeros_like(masked_rewards) + for t in range(T): + for j in range(t, T): + expected_r2g[t] += (gamma**(j - t)) * masked_rewards[j] + + self.assertAllClose(expected_r2g, actual_r2g) + + def test_value_loss(self): + rewards = np.array([ + [1, 2, 4, 8, 16, 32, 64, 128], + [1, 1, 1, 1, 1, 1, 1, 1], + ]) + + rewards_mask = np.array([ + [1, 1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 0], + ]) + + gamma = 0.5 + + # Random observations and a value function that returns a constant value. + # NOTE: Observations have an extra time-step. + B, T = rewards.shape # pylint: disable=invalid-name + observation_shape = (210, 160, 3) # atari pong + random_observations = np.random.uniform(size=(B, T + 1) + observation_shape) + + def value_net_apply(observations, params, rng=None): + del params, rng + # pylint: disable=invalid-name + B, T_p_1, OBS = (observations.shape[0], observations.shape[1], + observations.shape[2:]) + del OBS + return np.ones((B, T_p_1)) + # pylint: enable=invalid-name + + value_prediction = value_net_apply(random_observations, []) + + with jax.disable_jit(): + (value_loss, _) = ppo.value_loss_given_predictions( + value_prediction, + rewards, + rewards_mask, + gamma) + + self.assertNear(53.3637084961, value_loss, 1e-6) + + def test_deltas(self): + rewards = np.array([ + [1, 2, 4, 8, 16, 32, 64, 128], + [1, 1, 1, 1, 1, 1, 1, 1], + ]) + + rewards_mask = np.array([ + [1, 1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 0], + ]) + + B, T = rewards.shape # pylint: disable=invalid-name + + # Say, all predicted values are 1. + predicted_values = np.ones((B, T + 1)) + + gamma = 1.0 + + td_residuals = ppo.deltas(predicted_values, rewards, rewards_mask, gamma) + + # With V(s) being the same for all s, td_residuals should be + # equal to the rewards + (\gamma - 1)*v(s), masked in the right places. + truncated_pv = predicted_values[:, :-1] + masked_rewards = rewards * rewards_mask + expected_residuals = (masked_rewards + + (gamma - 1) * truncated_pv) * rewards_mask + self.assertAllEqual(expected_residuals, td_residuals) + + gamma = 0.5 + td_residuals = ppo.deltas(predicted_values, rewards, rewards_mask, gamma) + expected_residuals = (masked_rewards + + (gamma - 1) * truncated_pv) * rewards_mask + self.assertAllEqual(expected_residuals, td_residuals) + + def test_gae_advantages(self): + td_deltas = np.array([ + [1, 2, 4, 8, 16, 32, 64, 128], + [1, 1, 1, 1, 1, 1, 1, 1], + ]) + + rewards_mask = np.array([ + [1, 1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 0], + ]) + + gamma = 0.5 + lambda_ = 1.0 + + expected_gae_advantages = np.array([ + [5, 8, 12, 16, 16, 0, 0, 0], + [1.984375, 1.96875, 1.9375, 1.875, 1.75, 1.5, 1.0, 0], + ]) + + gae_advantages = ppo.gae_advantages(td_deltas * rewards_mask, rewards_mask, + lambda_, gamma) + self.assertAllEqual(expected_gae_advantages, gae_advantages) + + gamma = 1.0 + lambda_ = 0.5 + + gae_advantages = ppo.gae_advantages(td_deltas * rewards_mask, rewards_mask, + lambda_, gamma) + self.assertAllEqual(expected_gae_advantages, gae_advantages) + + def test_chosen_probabs(self): + # Shape (2, 2, 3) + probab_observations = np.array( + [[[0.1, 0.2, 0.7], [0.4, 0.1, 0.5]], + [[0.3, 0.1, 0.6], [0.1, 0.1, 0.8]]] + ) + + # Shape (2, 2, 1) + actions = np.array([[1, 2], [0, 1]]) + + chosen_probabs = ppo.chosen_probabs(probab_observations, actions) + + self.assertAllEqual( + np.array([[0.2, 0.5], [0.3, 0.1]]), chosen_probabs) + + def test_compute_probab_ratios(self): + p_old = np.array([[ + [np.log(0.1), np.log(0.2), np.log(0.6), np.log(0.1)], + [np.log(0.4), np.log(0.1), np.log(0.4), np.log(0.1)], + [np.log(0.3), np.log(0.1), np.log(0.5), np.log(0.1)], + [np.log(0.1), np.log(0.2), np.log(0.6), np.log(0.1)], + ], [ + [np.log(0.3), np.log(0.1), np.log(0.5), np.log(0.1)], + [np.log(0.1), np.log(0.1), np.log(0.4), np.log(0.4)], + [np.log(0.3), np.log(0.1), np.log(0.5), np.log(0.1)], + [np.log(0.1), np.log(0.2), np.log(0.6), np.log(0.1)], + ]]) + + p_new = np.array([[ + [np.log(0.3), np.log(0.1), np.log(0.5), np.log(0.1)], + [np.log(0.4), np.log(0.1), np.log(0.1), np.log(0.3)], + [np.log(0.1), np.log(0.2), np.log(0.1), np.log(0.6)], + [np.log(0.3), np.log(0.1), np.log(0.5), np.log(0.1)], + ], [ + [np.log(0.1), np.log(0.2), np.log(0.1), np.log(0.6)], + [np.log(0.1), np.log(0.1), np.log(0.2), np.log(0.6)], + [np.log(0.3), np.log(0.1), np.log(0.3), np.log(0.3)], + [np.log(0.1), np.log(0.2), np.log(0.1), np.log(0.6)], + ]]) + + actions = np.array([[1, 2, 0, 1], [0, 3, 3, 0]]) + + mask = np.array([[1, 1, 0, 0], [1, 1, 1, 0]]) + + probab_ratios = ppo.compute_probab_ratios(p_new, p_old, actions, mask) + + self.assertAllClose( + np.array([ + [0.1 / 0.2, 0.1 / 0.4, 0.0, 0.0], + [0.1 / 0.3, 0.6 / 0.4, 0.3 / 0.1, 0.0], + ]), probab_ratios) + + def test_clipped_probab_ratios(self): + probab_ratios = np.array([ + [1.5, 1.0, 0.5, 0.7], + [2.5, 2.0, 0.1, 1.0], + ]) + + clipped_probab_ratios = ppo.clipped_probab_ratios(probab_ratios, 0.1) + + self.assertAllClose( + np.array([ + [1.1, 1.0, 0.9, 0.9], + [1.1, 1.1, 0.9, 1.0], + ]), clipped_probab_ratios) + + def test_clipped_objective(self): + probab_ratios = np.array([ + [1.5, 2.0, 0.5, 0.7], + [2.5, 2.0, 0.1, 1.0], + ]) + + advantages = np.array([ + [0.1, -0.1, 0.5, 0.7], + [2.0, -2.0, 2.0, 2.0], + ]) + + mask = np.array([[1, 1, 0, 0], [1, 1, 1, 0]]) + + epsilon = 0.1 + + clipped_probab_ratios = np.array([ + [1.1, 1.1, 0.9, 0.9], + [1.1, 1.1, 0.9, 1.0], + ]) + + unused_advantages_x_probab_ratios = np.array([ + [0.15, -0.2, 0.25, 0.49], + [5.00, -4.0, 0.20, 2.00] + ]) + + unused_advantages_x_clipped_probab_ratios = np.array([ + [0.11, -0.11, 0.45, 0.63], + [2.20, -2.20, .80, 2.00] + ]) + + unused_minimums = np.array([ + [0.11, -0.2, 0.25, 0.49], + [2.20, -4.0, 0.20, 2.00] + ]) + + # minimums * mask + objective = np.array([ + [0.11, -0.2, 0.0, 0.], + [2.20, -4.0, 0.2, 0.] + ]) + + # Assert that we computed things correctly in this test. + self.assertAllClose( + np.minimum(probab_ratios * advantages, + clipped_probab_ratios * advantages) * mask, + objective) + + self.assertAllClose( + objective, + ppo.clipped_objective(probab_ratios, advantages, mask, epsilon)) + + def test_combined_loss(self): + self.rng_key, key1, key2 = jax_random.split(self.rng_key, num=3) + + B, T, A, OBS = 2, 10, 2, (28, 28, 3) # pylint: disable=invalid-name + batch_observation_shape = (1, 1) + OBS + + net = ppo.policy_and_value_net( + n_controls=1, + n_actions=A, + vocab_size=None, + bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)], + two_towers=True, + ) + + old_params, _ = net.initialize_once( + batch_observation_shape, np.float32, key1) + new_params, state = net.initialize_once( + batch_observation_shape, np.float32, key2) + + # Generate a batch of observations. + + observations = np.random.uniform(size=(B, T + 1) + OBS) + actions = np.random.randint(0, A, size=(B, T + 1)) + rewards = np.random.uniform(0, 1, size=(B, T)) + mask = np.ones_like(rewards) + + # Just test that this computes at all. + (new_log_probabs, value_predictions_new) = ( + net(observations, params=new_params, state=state)) + (old_log_probabs, value_predictions_old) = ( + net(observations, params=old_params, state=state)) + + gamma = 0.99 + lambda_ = 0.95 + epsilon = 0.2 + c1 = 1.0 + c2 = 0.01 + + rewards_to_actions = np.eye(value_predictions_old.shape[1]) + (value_loss_1, _) = ppo.value_loss_given_predictions( + value_predictions_new, rewards, mask, gamma=gamma, + value_prediction_old=value_predictions_old, epsilon=epsilon) + (ppo_loss_1, _) = ppo.ppo_loss_given_predictions( + new_log_probabs, + old_log_probabs, + value_predictions_old, + actions, + rewards_to_actions, + rewards, + mask, + gamma=gamma, + lambda_=lambda_, + epsilon=epsilon) + + (combined_loss, (ppo_loss_2, value_loss_2, entropy_bonus), _, state) = ( + ppo.combined_loss(new_params, + old_log_probabs, + value_predictions_old, + net, + observations, + actions, + rewards_to_actions, + rewards, + mask, + gamma=gamma, + lambda_=lambda_, + epsilon=epsilon, + c1=c1, + c2=c2, + state=state) + ) + + # Test that these compute at all and are self consistent. + self.assertGreater(entropy_bonus, 0.0) + self.assertNear(value_loss_1, value_loss_2, 1e-6) + self.assertNear(ppo_loss_1, ppo_loss_2, 1e-6) + self.assertNear(combined_loss, + ppo_loss_2 + (c1 * value_loss_2) - (c2 * entropy_bonus), + 1e-6) + + def test_masked_entropy(self): + # (2, 4+1, 4) + log_probs = np.array([[ + [np.log(0.1), np.log(0.2), np.log(0.6), np.log(0.1)], + [np.log(0.4), np.log(0.1), np.log(0.4), np.log(0.1)], + [np.log(0.3), np.log(0.1), np.log(0.5), np.log(0.1)], + [np.log(0.1), np.log(0.2), np.log(0.6), np.log(0.1)], + [np.log(0.3), np.log(0.1), np.log(0.5), np.log(0.1)], + ], [ + [np.log(0.3), np.log(0.1), np.log(0.5), np.log(0.1)], + [np.log(0.1), np.log(0.1), np.log(0.4), np.log(0.4)], + [np.log(0.3), np.log(0.1), np.log(0.5), np.log(0.1)], + [np.log(0.1), np.log(0.2), np.log(0.6), np.log(0.1)], + [np.log(0.3), np.log(0.1), np.log(0.5), np.log(0.1)], + ]]) + + # (2, 4) + mask = np.array([ + [1, 1, 0, 0, 0], + [1, 1, 1, 0, 0] + ]) + + def plp(p): + return p * np.log(p) + + # Removing the last time-step and the masked stuff, gets us this. + filtered_log_probs = np.array([[ + [plp(0.1), plp(0.2), plp(0.6), plp(0.1)], + [plp(0.4), plp(0.1), plp(0.4), plp(0.1)], + [plp(0.3), plp(0.1), plp(0.5), plp(0.1)], + [plp(0.1), plp(0.1), plp(0.4), plp(0.4)], + [plp(0.3), plp(0.1), plp(0.5), plp(0.1)], + ]]) + + self.assertNear(ppo.masked_entropy(log_probs, mask), + -np.sum(filtered_log_probs) / 5.0, + 1e-6) + + def test_saves_and_restores_opt_state(self): + opt_state = 123 + state = 456 + epoch = 7 + opt_step = 89 + output_dir = self.get_temp_dir() + ppo.save_opt_state(output_dir, opt_state, state, epoch, opt_step) + restored_data = ppo.maybe_restore_opt_state(output_dir) + self.assertEqual(restored_data, (opt_state, state, epoch, opt_step)) + + def test_inits_policy_by_world_model_checkpoint(self): + transformer_kwargs = { + "d_model": 1, + "d_ff": 1, + "n_layers": 1, + "n_heads": 1, + "max_len": 128, + "mode": "train", + } + rng = jax_random.PRNGKey(123) + init_kwargs = { + "input_shapes": (1, 1), + "input_dtype": np.int32, + "rng": rng, + } + model_fn = functools.partial( + models.TransformerLM, vocab_size=4, **transformer_kwargs + ) + output_dir = self.get_temp_dir() + # Initialize a world model checkpoint by running the trainer. + trainer_lib.train( + output_dir, + model=model_fn, + inputs=functools.partial( + inputs.random_inputs, input_shape=(1, 1), output_shape=(1, 1) + ), + train_steps=1, + eval_steps=1, + ) + + policy = ppo.policy_and_value_net( + n_actions=3, + n_controls=2, + vocab_size=4, + bottom_layers_fn=functools.partial( + models.TransformerDecoder, **transformer_kwargs + ), + two_towers=False, + ) + (policy_params, policy_state) = policy.initialize_once(**init_kwargs) + + # Initialize policy parameters from world model parameters. + new_policy_params = ppo.init_policy_from_world_model_checkpoint( + policy_params, output_dir + ) + # Try to run the policy with new parameters. + observations = np.zeros((1, 100), dtype=np.int32) + policy(observations, params=new_policy_params, state=policy_state, rng=rng) + + def test_shuffled_index_batches_generates_valid_batch(self): + dataset_size = 16 + batch_size = 4 + stream = ppo.shuffled_index_batches(dataset_size, batch_size) + batch = next(stream) + self.assertEqual(batch.shape, (batch_size,)) + # Assert that all indices are different. + self.assertEqual(len(set(batch)), batch_size) + + def test_shuffled_index_batches_generates_all_indices(self): + dataset_size = 16 + batch_size = 4 + stream = ppo.shuffled_index_batches(dataset_size, batch_size) + indices = np.reshape( + list(itertools.islice(stream, dataset_size // batch_size)), -1 + ) + self.assertEqual(set(indices), set(range(dataset_size))) + + def test_shuffled_index_batches_gives_different_permutations(self): + dataset_size = 256 + batch_size = 8 + stream1 = ppo.shuffled_index_batches(dataset_size, batch_size) + stream2 = ppo.shuffled_index_batches(dataset_size, batch_size) + self.assertFalse(np.array_equal(next(stream1), next(stream2))) + + +if __name__ == "__main__": + test.main() diff --git a/trax/rl/ppo_trainer.py b/trax/rl/ppo_trainer.py new file mode 100644 index 000000000..de6e49815 --- /dev/null +++ b/trax/rl/ppo_trainer.py @@ -0,0 +1,844 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PPO trainer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import functools +import os +import time + +from absl import logging +import gym +from jax import jit +from jax import numpy as np +from jax import random as jax_random +import numpy as onp +from tensor2tensor.envs import env_problem_utils +from tensor2tensor.envs import trajectory +from trax import jaxboard +from trax import models as trax_models +from trax import optimizers as trax_opt +from trax import trainer_lib +from trax.rl import base_trainer +from trax.rl import ppo +from trax.rl import serialization_utils +from trax.rl import space_serializer + +DEBUG_LOGGING = False +GAMMA = 0.99 +LAMBDA = 0.95 +EPSILON = 0.1 +EPOCHS = 50 # 100 +N_OPTIMIZER_STEPS = 100 +PRINT_EVERY_OPTIMIZER_STEP = 20 +BATCH_TRAJECTORIES = 32 + + +class PPO(base_trainer.BaseTrainer): + """PPO trainer.""" + + def __init__(self, + train_env, + eval_env, + output_dir, + policy_and_value_model=trax_models.FrameStackMLP, + policy_and_value_optimizer=functools.partial( + trax_opt.Adam, learning_rate=1e-3), + policy_and_value_two_towers=False, + policy_and_value_vocab_size=None, + n_optimizer_steps=N_OPTIMIZER_STEPS, + optimizer_batch_size=64, + print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP, + target_kl=0.01, + boundary=20, + max_timestep=100, + max_timestep_eval=20000, + random_seed=None, + gamma=GAMMA, + lambda_=LAMBDA, + c1=1.0, + c2=0.01, + eval_every_n=1000, + save_every_n=1000, + done_frac_for_policy_save=0.5, + n_evals=1, + len_history_for_policy=4, + eval_temperatures=(1.0, 0.5), + separate_eval=True, + init_policy_from_world_model_output_dir=None, + **kwargs): + """Creates the PPO trainer. + + Args: + train_env: gym.Env to use for training. + eval_env: gym.Env to use for evaluation. + output_dir: Output dir. + policy_and_value_model: Function defining the policy and value network, + without the policy and value heads. + policy_and_value_optimizer: Function defining the optimizer. + policy_and_value_two_towers: Whether to use two separate models as the + policy and value networks. If False, share their parameters. + policy_and_value_vocab_size: Vocabulary size of a policy and value network + operating on serialized representation. If None, use raw continuous + representation. + n_optimizer_steps: Number of optimizer steps. + optimizer_batch_size: Batch size of an optimizer step. + print_every_optimizer_steps: How often to log during the policy + optimization process. + target_kl: Policy iteration early stopping. Set to infinity to disable + early stopping. + boundary: We pad trajectories at integer multiples of this number. + max_timestep: If set to an integer, maximum number of time-steps in a + trajectory. Used in the collect procedure. + max_timestep_eval: If set to an integer, maximum number of time-steps in + an evaluation trajectory. Used in the collect procedure. + random_seed: Random seed. + gamma: Reward discount factor. + lambda_: N-step TD-error discount factor in GAE. + c1: Value loss coefficient. + c2: Entropy loss coefficient. + eval_every_n: How frequently to eval the policy. + save_every_n: How frequently to save the policy. + done_frac_for_policy_save: Fraction of the trajectories that should be + done to checkpoint the policy. + n_evals: Number of times to evaluate. + len_history_for_policy: How much of history to give to the policy. + eval_temperatures: Sequence of temperatures to try for categorical + sampling during evaluation. + separate_eval: Whether to run separate evaluation using a set of + temperatures. If False, the training reward is reported as evaluation + reward with temperature 1.0. + init_policy_from_world_model_output_dir: Model output dir for initializing + the policy. If None, initialize randomly. + **kwargs: Additional keyword arguments passed to the base class. + """ + # Set in base class constructor. + self._train_env = None + self._should_reset = None + + super(PPO, self).__init__(train_env, eval_env, output_dir, **kwargs) + + self._n_optimizer_steps = n_optimizer_steps + self._optimizer_batch_size = optimizer_batch_size + self._print_every_optimizer_steps = print_every_optimizer_steps + self._target_kl = target_kl + self._boundary = boundary + self._max_timestep = max_timestep + self._max_timestep_eval = max_timestep_eval + self._gamma = gamma + self._lambda_ = lambda_ + self._c1 = c1 + self._c2 = c2 + self._eval_every_n = eval_every_n + self._save_every_n = save_every_n + self._done_frac_for_policy_save = done_frac_for_policy_save + self._n_evals = n_evals + self._len_history_for_policy = len_history_for_policy + self._eval_temperatures = eval_temperatures + self._separate_eval = separate_eval + + action_space = self.train_env.action_space + assert isinstance( + action_space, (gym.spaces.Discrete, gym.spaces.MultiDiscrete)) + if isinstance(action_space, gym.spaces.Discrete): + n_actions = action_space.n + n_controls = 1 + else: + (n_controls,) = action_space.nvec.shape + assert n_controls > 0 + assert onp.min(action_space.nvec) == onp.max(action_space.nvec), ( + "Every control must have the same number of actions.") + n_actions = action_space.nvec[0] + self._n_actions = n_actions + self._n_controls = n_controls + + self._rng = trainer_lib.get_random_number_generator_and_set_seed( + random_seed) + self._rng, key1 = jax_random.split(self._rng, num=2) + + vocab_size = policy_and_value_vocab_size + self._serialized_sequence_policy = vocab_size is not None + if self._serialized_sequence_policy: + self._serialization_kwargs = self._init_serialization(vocab_size) + else: + self._serialization_kwargs = {} + + # Initialize the policy and value network. + policy_and_value_net = ppo.policy_and_value_net( + n_actions=n_actions, + n_controls=n_controls, + vocab_size=vocab_size, + bottom_layers_fn=policy_and_value_model, + two_towers=policy_and_value_two_towers, + ) + self._policy_and_value_net_apply = jit(policy_and_value_net) + (batch_obs_shape, obs_dtype) = self._batch_obs_shape_and_dtype + policy_and_value_net_params, self._model_state = ( + policy_and_value_net.initialize_once(batch_obs_shape, obs_dtype, key1)) + if init_policy_from_world_model_output_dir is not None: + policy_and_value_net_params = ppo.init_policy_from_world_model_checkpoint( + policy_and_value_net_params, init_policy_from_world_model_output_dir + ) + + # Initialize the optimizer. + (policy_and_value_opt_state, self._policy_and_value_opt_update, + self._policy_and_value_get_params) = ppo.optimizer_fn( + policy_and_value_optimizer, policy_and_value_net_params) + + # Restore the optimizer state. + self._policy_and_value_opt_state = policy_and_value_opt_state + self._epoch = 0 + self._total_opt_step = 0 + self.update_optimization_state( + output_dir, policy_and_value_opt_state=policy_and_value_opt_state) + + # Create summary writers and history. + self._train_sw = jaxboard.SummaryWriter( + os.path.join(self._output_dir, "train")) + self._timing_sw = jaxboard.SummaryWriter( + os.path.join(self._output_dir, "timing")) + self._eval_sw = jaxboard.SummaryWriter( + os.path.join(self._output_dir, "eval")) + + self._n_trajectories_done = 0 + + self._last_saved_at = 0 + if self._async_mode: + logging.info("Saving model on startup to have a model policy file.") + self.save() + + self._rewards_to_actions = self._init_rewards_to_actions() + + def _init_serialization(self, vocab_size): + obs_serializer = space_serializer.create( + self.train_env.observation_space, vocab_size=vocab_size + ) + act_serializer = space_serializer.create( + self.train_env.action_space, vocab_size=vocab_size + ) + repr_length = ( + obs_serializer.representation_length + + act_serializer.representation_length + ) * (self._max_timestep + 1) + return { + "observation_serializer": obs_serializer, + "action_serializer": act_serializer, + "representation_length": repr_length, + } + + def _init_rewards_to_actions(self): + # Linear map from the reward sequence to the action sequence, used for + # scattering advantages over action log-probs and some other things. + # It has one more timestep at the end, so it's compatible with the value + # predictions. + if not self._serialized_sequence_policy: + rewards_to_actions = np.eye(self._max_timestep + 1)[:, None, :] + rewards_to_actions = np.broadcast_to( + rewards_to_actions, + (self._max_timestep + 1, self._n_controls, self._max_timestep + 1), + ) + return np.reshape(rewards_to_actions, (self._max_timestep + 1, -1)) + else: + return serialization_utils.rewards_to_actions_map( + n_timesteps=(self._max_timestep + 1), **self._serialization_kwargs + ) + + @property + def _batch_obs_shape_and_dtype(self): + if not self._serialized_sequence_policy: + # Batch Observations Shape = [1, 1] + OBS, because we will eventually call + # policy and value networks on shape [B, T] +_OBS + shape = (1, 1) + self.train_env.observation_space.shape + dtype = self.train_env.observation_space.dtype + else: + shape = (1, 1) + dtype = np.int32 + return (shape, dtype) + + # Maybe restore the optimization state. If there is nothing to restore, then + # epoch = 0 and policy_and_value_opt_state is returned as is. + def update_optimization_state(self, + output_dir, + policy_and_value_opt_state=None): + (self._policy_and_value_opt_state, self._model_state, self._epoch, + self._total_opt_step) = ppo.maybe_restore_opt_state( + output_dir, policy_and_value_opt_state, self._model_state) + + if self._epoch > 0: + logging.info("Restored parameters from epoch [%d]", self._epoch) + + @property + def train_env(self): + return self._train_env + + @train_env.setter + def train_env(self, new_train_env): + if self._train_env is not None: + + def assert_same_space(space1, space2): + assert space1.shape == space2.shape + assert space1.dtype == space2.dtype + + assert_same_space(new_train_env.observation_space, + self._train_env.observation_space) + assert_same_space(new_train_env.action_space, + self._train_env.action_space) + # We don't check the reward range, because PPO will work either way. + + self._train_env = new_train_env + self._should_reset = True + + @property + def epoch(self): + return self._epoch + + def collect_trajectories_async(self, + env, + train=True, + n_trajectories=1, + temperature=1.0): + """Collects trajectories in an async manner.""" + + assert self._async_mode + + # trajectories/train and trajectories/eval are the two subdirectories. + trajectory_dir = os.path.join(self._output_dir, "trajectories", + "train" if train else "eval") + epoch = self.epoch + + logging.info( + "Loading [%s] trajectories from dir [%s] for epoch [%s] and temperature" + " [%s]", n_trajectories, trajectory_dir, epoch, temperature) + + bt = trajectory.BatchTrajectory.load_from_directory( + trajectory_dir, + epoch=epoch, + temperature=temperature, + wait_forever=True, + n_trajectories=n_trajectories) + + if bt is None: + logging.error( + "Couldn't load [%s] trajectories from dir [%s] for epoch [%s] and " + "temperature [%s]", n_trajectories, trajectory_dir, epoch, + temperature) + assert bt + + # Doing this is important, since we want to modify `env` so that it looks + # like `env` was actually played and the trajectories came from it. + env.trajectories = bt + + trajs = env_problem_utils.get_completed_trajectories_from_env( + env, n_trajectories) + n_done = len(trajs) + timing_info = {} + return trajs, n_done, timing_info, self._model_state + + def collect_trajectories(self, + train=True, + temperature=1.0, + abort_fn=None, + raw_trajectory=False): + self._rng, key = jax_random.split(self._rng) + + env = self.train_env + max_timestep = self._max_timestep + should_reset = self._should_reset + if not train: # eval + env = self.eval_env + max_timestep = self._max_timestep_eval + should_reset = True + + n_trajectories = env.batch_size + + # If async, read the required trajectories for the epoch. + if self._async_mode: + trajs, n_done, timing_info, self._model_state = self.collect_trajectories_async( + env, + train=train, + n_trajectories=n_trajectories, + temperature=temperature) + else: + trajs, n_done, timing_info, self._model_state = ppo.collect_trajectories( + env, + policy_fn=self._policy_fun, + n_trajectories=n_trajectories, + max_timestep=max_timestep, + state=self._model_state, + rng=key, + len_history_for_policy=self._len_history_for_policy, + boundary=self._boundary, + reset=should_reset, + temperature=temperature, + abort_fn=abort_fn, + raw_trajectory=raw_trajectory, + ) + + if train: + self._n_trajectories_done += n_done + + return trajs, n_done, timing_info, self._model_state + + def train_epoch(self, evaluate=True): + """Train one PPO epoch.""" + epoch_start_time = time.time() + + # Evaluate the policy. + policy_eval_start_time = time.time() + if evaluate and (self._epoch + 1) % self._eval_every_n == 0: + self._rng, key = jax_random.split(self._rng, num=2) + self.evaluate() + + policy_eval_time = ppo.get_time(policy_eval_start_time) + + trajectory_collection_start_time = time.time() + logging.vlog(1, "PPO epoch [% 6d]: collecting trajectories.", self._epoch) + self._rng, key = jax_random.split(self._rng) + trajs, _, timing_info, self._model_state = self.collect_trajectories( + train=True, temperature=1.0) + trajs = [(t[0], t[1], t[2], t[4]) for t in trajs] + self._should_reset = False + trajectory_collection_time = ppo.get_time(trajectory_collection_start_time) + + logging.vlog(1, "Collecting trajectories took %0.2f msec.", + trajectory_collection_time) + + rewards = np.array([np.sum(traj[2]) for traj in trajs]) + avg_reward = np.mean(rewards) + std_reward = np.std(rewards) + max_reward = np.max(rewards) + min_reward = np.min(rewards) + + self._train_sw.scalar( + "train/reward_mean_truncated", avg_reward, step=self._epoch) + if evaluate and not self._separate_eval: + metrics = {"raw": {1.0: {"mean": avg_reward, "std": std_reward}}} + ppo.write_eval_reward_summaries(metrics, self._eval_sw, self._epoch) + + logging.vlog(1, "Rewards avg=[%0.2f], max=[%0.2f], min=[%0.2f], all=%s", + avg_reward, max_reward, min_reward, + [float(np.sum(traj[2])) for traj in trajs]) + + logging.vlog(1, + "Trajectory Length average=[%0.2f], max=[%0.2f], min=[%0.2f]", + float(sum(len(traj[0]) for traj in trajs)) / len(trajs), + max(len(traj[0]) for traj in trajs), + min(len(traj[0]) for traj in trajs)) + logging.vlog(2, "Trajectory Lengths: %s", [len(traj[0]) for traj in trajs]) + + preprocessing_start_time = time.time() + (padded_observations, padded_actions, padded_rewards, reward_mask, + padded_infos) = self._preprocess_trajectories(trajs) + preprocessing_time = ppo.get_time(preprocessing_start_time) + + logging.vlog(1, "Preprocessing trajectories took %0.2f msec.", + ppo.get_time(preprocessing_start_time)) + logging.vlog(1, "Padded Observations' shape [%s]", + str(padded_observations.shape)) + logging.vlog(1, "Padded Actions' shape [%s]", str(padded_actions.shape)) + logging.vlog(1, "Padded Rewards' shape [%s]", str(padded_rewards.shape)) + + # Some assertions. + B, RT = padded_rewards.shape # pylint: disable=invalid-name + B, AT = padded_actions.shape # pylint: disable=invalid-name + assert (B, RT) == reward_mask.shape + assert B == padded_observations.shape[0] + + log_prob_recompute_start_time = time.time() + # TODO(pkozakowski): The following commented out code collects the network + # predictions made while stepping the environment and uses them in PPO + # training, so that we can use non-deterministic networks (e.g. with + # dropout). This does not work well with serialization, so instead we + # recompute all network predictions. Let's figure out a solution that will + # work with both serialized sequences and non-deterministic networks. + + # assert ("log_prob_actions" in padded_infos and + # "value_predictions" in padded_infos) + # These are the actual log-probabs and value predictions seen while picking + # the actions. + # actual_log_probabs_traj = padded_infos["log_prob_actions"] + # actual_value_predictions_traj = padded_infos["value_predictions"] + + # assert (B, T, C) == actual_log_probabs_traj.shape[:3] + # A = actual_log_probabs_traj.shape[3] # pylint: disable=invalid-name + # assert (B, T, 1) == actual_value_predictions_traj.shape + + del padded_infos + + # TODO(afrozm): log-probabs doesn't need to be (B, T+1, C, A) it can do with + # (B, T, C, A), so make that change throughout. + + # NOTE: We don't have the log-probabs and value-predictions for the last + # observation, so we re-calculate for everything, but use the original ones + # for all but the last time-step. + self._rng, key = jax_random.split(self._rng) + + log_probabs_traj, value_predictions_traj, self._model_state, _ = ( + self._get_predictions(padded_observations, self._model_state, rng=key)) + + assert (B, AT) == log_probabs_traj.shape[:2] + assert (B, AT) == value_predictions_traj.shape + + # TODO(pkozakowski): Commented out for the same reason as before. + + # Concatenate the last time-step's log-probabs and value predictions to the + # actual log-probabs and value predictions and use those going forward. + # log_probabs_traj = np.concatenate( + # (actual_log_probabs_traj, log_probabs_traj[:, -1:, :]), axis=1) + # value_predictions_traj = np.concatenate( + # (actual_value_predictions_traj, value_predictions_traj[:, -1:, :]), + # axis=1) + + log_prob_recompute_time = ppo.get_time(log_prob_recompute_start_time) + + # Compute value and ppo losses. + self._rng, key1 = jax_random.split(self._rng, num=2) + logging.vlog(2, "Starting to compute P&V loss.") + loss_compute_start_time = time.time() + (cur_combined_loss, component_losses, summaries, self._model_state) = ( + ppo.combined_loss( + self._policy_and_value_net_params, + log_probabs_traj, + value_predictions_traj, + self._policy_and_value_net_apply, + padded_observations, + padded_actions, + self._rewards_to_actions, + padded_rewards, + reward_mask, + gamma=self._gamma, + lambda_=self._lambda_, + c1=self._c1, + c2=self._c2, + state=self._model_state, + rng=key1)) + loss_compute_time = ppo.get_time(loss_compute_start_time) + (cur_ppo_loss, cur_value_loss, cur_entropy_bonus) = component_losses + logging.vlog( + 1, + "Calculating P&V loss [%10.2f(%10.2f, %10.2f, %10.2f)] took %0.2f msec.", + cur_combined_loss, cur_ppo_loss, cur_value_loss, cur_entropy_bonus, + ppo.get_time(loss_compute_start_time)) + + self._rng, key1 = jax_random.split(self._rng, num=2) + logging.vlog(1, "Policy and Value Optimization") + optimization_start_time = time.time() + keys = jax_random.split(key1, num=self._n_optimizer_steps) + opt_step = 0 + opt_batch_size = min(self._optimizer_batch_size, B) + index_batches = ppo.shuffled_index_batches( + dataset_size=B, batch_size=opt_batch_size + ) + for (index_batch, key) in zip(index_batches, keys): + k1, k2, k3 = jax_random.split(key, num=3) + t = time.time() + # Update the optimizer state on the sampled minibatch. + self._policy_and_value_opt_state, self._model_state = ( + ppo.policy_and_value_opt_step( + # We pass the optimizer slots between PPO epochs, so we need to + # pass the optimization step as well, so for example the + # bias-correction in Adam is calculated properly. Alternatively we + # could reset the slots and the step in every PPO epoch, but then + # the moment estimates in adaptive optimizers would never have + # enough time to warm up. So it makes sense to reuse the slots, + # even though we're optimizing a different loss in every new + # epoch. + self._total_opt_step, + self._policy_and_value_opt_state, + self._policy_and_value_opt_update, + self._policy_and_value_get_params, + self._policy_and_value_net_apply, + log_probabs_traj[index_batch], + value_predictions_traj[index_batch], + padded_observations[index_batch], + padded_actions[index_batch], + self._rewards_to_actions, + padded_rewards[index_batch], + reward_mask[index_batch], + c1=self._c1, + c2=self._c2, + gamma=self._gamma, + lambda_=self._lambda_, + state=self._model_state, + rng=k1)) + opt_step += 1 + self._total_opt_step += 1 + + # Compute the approx KL for early stopping. Use the whole dataset - as we + # only do inference, it should fit in the memory. + (log_probab_actions_new, _) = ( + self._policy_and_value_net_apply( + padded_observations, + params=self._policy_and_value_net_params, + state=self._model_state, + rng=k2)) + + action_mask = np.dot( + np.pad(reward_mask, ((0, 0), (0, 1))), self._rewards_to_actions + ) + approx_kl = ppo.approximate_kl(log_probab_actions_new, log_probabs_traj, + action_mask) + + early_stopping = approx_kl > 1.5 * self._target_kl + if early_stopping: + logging.vlog( + 1, "Early stopping policy and value optimization after %d steps, " + "with approx_kl: %0.2f", opt_step, approx_kl) + # We don't return right-away, we want the below to execute on the last + # iteration. + + t2 = time.time() + if (opt_step % self._print_every_optimizer_steps == 0 or + opt_step == self._n_optimizer_steps or early_stopping): + # Compute and log the loss. + (combined_loss, component_losses, _, self._model_state) = ( + ppo.combined_loss( + self._policy_and_value_net_params, + log_probabs_traj, + value_predictions_traj, + self._policy_and_value_net_apply, + padded_observations, + padded_actions, + self._rewards_to_actions, + padded_rewards, + reward_mask, + gamma=self._gamma, + lambda_=self._lambda_, + c1=self._c1, + c2=self._c2, + state=self._model_state, + rng=k3)) + logging.vlog(1, "One Policy and Value grad desc took: %0.2f msec", + ppo.get_time(t, t2)) + (ppo_loss, value_loss, entropy_bonus) = component_losses + logging.vlog( + 1, "Combined Loss(value, ppo, entropy_bonus) [%10.2f] ->" + " [%10.2f(%10.2f,%10.2f,%10.2f)]", cur_combined_loss, combined_loss, + ppo_loss, value_loss, entropy_bonus) + + if early_stopping: + break + + optimization_time = ppo.get_time(optimization_start_time) + + logging.vlog( + 1, "Total Combined Loss reduction [%0.2f]%%", + (100 * (cur_combined_loss - combined_loss) / np.abs(cur_combined_loss))) + + summaries.update({ + "n_optimizer_steps": opt_step, + "approx_kl": approx_kl, + }) + for (name, value) in summaries.items(): + self._train_sw.scalar("train/{}".format(name), value, step=self._epoch) + + logging.info( + "PPO epoch [% 6d], Reward[min, max, avg] [%5.2f,%5.2f,%5.2f], Combined" + " Loss(ppo, value, entropy) [%2.5f(%2.5f,%2.5f,%2.5f)]", self._epoch, + min_reward, max_reward, avg_reward, combined_loss, ppo_loss, value_loss, + entropy_bonus) + + # Bump the epoch counter before saving a checkpoint, so that a call to + # save() after the training loop is a no-op if a checkpoint was saved last + # epoch - otherwise it would bump the epoch counter on the checkpoint. + last_epoch = self._epoch + self._epoch += 1 + + # Save parameters every time we see the end of at least a fraction of batch + # number of trajectories that are done (not completed -- completed includes + # truncated and done). + # Also don't save too frequently, enforce a minimum gap. + policy_save_start_time = time.time() + # TODO(afrozm): Refactor to trax.save_state. + if (self._n_trajectories_done >= + self._done_frac_for_policy_save * self.train_env.batch_size and + self._epoch % self._save_every_n == 0) or self._async_mode: + self.save() + policy_save_time = ppo.get_time(policy_save_start_time) + + epoch_time = ppo.get_time(epoch_start_time) + + timing_dict = { + "epoch": epoch_time, + "policy_eval": policy_eval_time, + "trajectory_collection": trajectory_collection_time, + "preprocessing": preprocessing_time, + "log_prob_recompute": log_prob_recompute_time, + "loss_compute": loss_compute_time, + "optimization": optimization_time, + "policy_save": policy_save_time, + } + + timing_dict.update(timing_info) + + for k, v in timing_dict.items(): + self._timing_sw.scalar("timing/%s" % k, v, step=last_epoch) + + max_key_len = max(len(k) for k in timing_dict) + timing_info_list = [ + "%s : % 10.2f" % (k.rjust(max_key_len + 1), v) + for k, v in sorted(timing_dict.items()) + ] + logging.info("PPO epoch [% 6d], Timings: \n%s", last_epoch, + "\n".join(timing_info_list)) + + # Flush summary writers once in a while. + if self._epoch % 1000 == 0: + self.flush_summaries() + + def evaluate(self): + """Evaluate the agent.""" + if not self._separate_eval: + return + logging.vlog(1, "PPO epoch [% 6d]: evaluating policy.", self._epoch) + + processed_reward_sums = collections.defaultdict(list) + raw_reward_sums = collections.defaultdict(list) + for _ in range(self._n_evals): + for temperature in self._eval_temperatures: + trajs, _, _, self._model_state = self.collect_trajectories( + train=False, temperature=temperature) + + processed_reward_sums[temperature].extend( + sum(traj[2]) for traj in trajs) + raw_reward_sums[temperature].extend(sum(traj[3]) for traj in trajs) + + # Return the mean and standard deviation for each temperature. + def compute_stats(reward_dict): + return { + temperature: { # pylint: disable=g-complex-comprehension + "mean": onp.mean(rewards), + "std": onp.std(rewards) + } for (temperature, rewards) in reward_dict.items() + } + + reward_stats = { + "processed": compute_stats(processed_reward_sums), + "raw": compute_stats(raw_reward_sums), + } + + ppo.write_eval_reward_summaries( + reward_stats, self._eval_sw, epoch=self._epoch) + + def save(self): + """Save the agent parameters.""" + logging.vlog(1, "PPO epoch [% 6d]: saving model.", self._epoch) + ppo.save_opt_state( + self._output_dir, + self._policy_and_value_opt_state, + self._model_state, + self._epoch, + self._total_opt_step, + ) + # Reset this number. + self._n_trajectories_done = 0 + self._last_saved_at = self._epoch + + def flush_summaries(self): + self._train_sw.flush() + self._timing_sw.flush() + self._eval_sw.flush() + + @property + def _policy_and_value_net_params(self): + return self._policy_and_value_get_params(self._policy_and_value_opt_state) + + # Prepares the trajectories for policy training. + def _preprocess_trajectories(self, trajectories): + (_, reward_mask, observations, actions, rewards, infos) = ( + ppo.pad_trajectories(trajectories, boundary=self._max_timestep) + ) + assert self.train_env.observation_space.shape == observations.shape[2:] + if not self._serialized_sequence_policy: + # Add one timestep at the end, so it's compatible with + # self._rewards_to_actions. + pad_width = ((0, 0), (0, 1)) + ((0, 0),) * (actions.ndim - 2) + actions = np.pad(actions, pad_width) + actions = np.reshape(actions, (actions.shape[0], -1)) + else: + (observations, actions) = self._serialize_trajectories( + observations, actions, reward_mask + ) + return (observations, actions, rewards, reward_mask, infos) + + def _serialize_trajectories(self, observations, actions, reward_mask): + (reprs, _) = serialization_utils.serialize_observations_and_actions( + observations=observations, + actions=actions, + mask=reward_mask, + **self._serialization_kwargs + ) + # Mask out actions in the representation - otherwise we sample an action + # based on itself. + observations = reprs * serialization_utils.observation_mask( + **self._serialization_kwargs + ) + actions = reprs + return (observations, actions) + + # A function to get the policy and value predictions. + def _get_predictions(self, observations, state, rng=None): + """Returns log-probs, value predictions and key back.""" + key, key1 = jax_random.split(rng, num=2) + + (log_probs, value_preds) = self._policy_and_value_net_apply( + observations, params=self._policy_and_value_net_params, state=state, + rng=key1) + + return log_probs, value_preds, state, key + + def _policy_fun(self, observations, lengths, state, rng): + (batch_size, n_timesteps) = observations.shape[:2] + if self._serialized_sequence_policy: + actions = np.zeros( + (batch_size, n_timesteps - 1) + self.train_env.action_space.shape, + dtype=self.train_env.action_space.dtype, + ) + reward_mask = np.ones((batch_size, n_timesteps - 1), dtype=np.int32) + (observations, _) = self._serialize_trajectories( + observations, actions, reward_mask + ) + (log_probs, value_preds, state, rng) = self._get_predictions( + observations, state=state, rng=rng + ) + # We need the log_probs of those actions that correspond to the last actual + # time-step. + index = lengths - 1 # Since we want to index using lengths. + pred_index = self._calc_action_index(index) + log_probs = log_probs[ + np.arange(batch_size)[:, None, None], + pred_index[:, :, None], + np.arange(self._n_actions), + ] + value_preds = value_preds[np.arange(batch_size)[:, None], pred_index] + return (log_probs, value_preds, state, rng) + + def _calc_action_index(self, reward_index): + # Project the one-hot position in the reward sequence onto the action + # sequence to figure out which actions correspond to that position. + one_hot_index = np.eye(self._rewards_to_actions.shape[0])[reward_index] + action_mask = np.dot(one_hot_index, self._rewards_to_actions) + # Compute the number of symbols in an action. It's just the number of 1s in + # the mask. + action_length = int(np.sum(action_mask[0])) + # Argmax stops on the first occurrence, so we use it to find the first 1 in + # the mask. + action_start_index = np.argmax(action_mask, axis=1) + return action_start_index[:, None] + np.arange(action_length)[None, :] diff --git a/trax/rl/ppo_trainer_test.py b/trax/rl/ppo_trainer_test.py new file mode 100644 index 000000000..d6823b8ed --- /dev/null +++ b/trax/rl/ppo_trainer_test.py @@ -0,0 +1,306 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.rl.ppo's training_loop.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import functools +import itertools +import os +import tempfile + +import gin +import gym +import numpy as np + +from tensor2tensor.envs import gym_env_problem +from tensor2tensor.rl import gym_utils +from tensorflow import test +from tensorflow.io import gfile +from trax import inputs as trax_inputs +from trax import layers +from trax import learning_rate as lr +from trax import models +from trax import optimizers as trax_opt +from trax import trainer_lib +from trax.rl import envs # pylint: disable=unused-import +from trax.rl import ppo_trainer +from trax.rl import simulated_env_problem + + +class PpoTrainerTest(test.TestCase): + + def get_wrapped_env( + self, name="CartPole-v0", max_episode_steps=2, batch_size=1 + ): + wrapper_fn = functools.partial( + gym_utils.gym_env_wrapper, + **{ + "rl_env_max_episode_steps": max_episode_steps, + "maxskip_env": False, + "rendered_env": False, + "rendered_env_resize_to": None, # Do not resize frames + "sticky_actions": False, + "output_dtype": None, + }) + + return gym_env_problem.GymEnvProblem(base_env_name=name, + batch_size=batch_size, + env_wrapper_fn=wrapper_fn, + discrete_rewards=False) + + @contextlib.contextmanager + def tmp_dir(self): + tmp = tempfile.mkdtemp(dir=self.get_temp_dir()) + yield tmp + gfile.rmtree(tmp) + + def _make_trainer( + self, train_env, eval_env, output_dir, model=None, **kwargs + ): + if model is None: + model = lambda: layers.Serial(layers.Dense(1)) + return ppo_trainer.PPO( + train_env=train_env, + eval_env=eval_env, + policy_and_value_model=model, + n_optimizer_steps=1, + output_dir=output_dir, + random_seed=0, + max_timestep=3, + boundary=2, + save_every_n=1, + **kwargs + ) + + def test_training_loop_cartpole(self): + with self.tmp_dir() as output_dir: + trainer = self._make_trainer( + train_env=self.get_wrapped_env("CartPole-v0", 2), + eval_env=self.get_wrapped_env("CartPole-v0", 2), + output_dir=output_dir, + ) + trainer.training_loop(n_epochs=2) + + def test_training_loop_cartpole_transformer(self): + with self.tmp_dir() as output_dir: + trainer = self._make_trainer( + train_env=self.get_wrapped_env("CartPole-v0", 2), + eval_env=self.get_wrapped_env("CartPole-v0", 2), + output_dir=output_dir, + model=functools.partial( + models.TransformerDecoder, + d_model=1, + d_ff=1, + n_layers=1, + n_heads=1, + max_len=128, + mode="train", + ), + ) + trainer.training_loop(n_epochs=2) + + def test_training_loop_onlinetune(self): + with self.tmp_dir() as output_dir: + gin.bind_parameter("OnlineTuneEnv.model", functools.partial( + models.MLP, + n_hidden_layers=0, + n_output_classes=1, + )) + gin.bind_parameter("OnlineTuneEnv.inputs", functools.partial( + trax_inputs.random_inputs, + input_shape=(1, 1), + input_dtype=np.float32, + output_shape=(1, 1), + output_dtype=np.float32, + )) + gin.bind_parameter("OnlineTuneEnv.train_steps", 1) + gin.bind_parameter("OnlineTuneEnv.eval_steps", 1) + gin.bind_parameter( + "OnlineTuneEnv.output_dir", os.path.join(output_dir, "envs")) + trainer = self._make_trainer( + train_env=self.get_wrapped_env("OnlineTuneEnv-v0", 1), + eval_env=self.get_wrapped_env("OnlineTuneEnv-v0", 1), + output_dir=output_dir, + ) + trainer.training_loop(n_epochs=1) + + def test_training_loop_simulated(self): + n_actions = 5 + history_shape = (3, 2, 3) + action_shape = (3,) + obs_shape = (3, 3) + reward_shape = (3, 1) + + def model(mode): + del mode + return layers.Serial( + layers.Parallel( + layers.Flatten(), # Observation stack. + layers.Embedding(d_feature=1, vocab_size=n_actions), # Action. + ), + layers.Concatenate(), + layers.Dense(n_units=1), + layers.Dup(), + layers.Parallel( + layers.Dense(n_units=obs_shape[1]), # New observation. + None, # Reward. + ) + ) + + def inputs(n_devices): + del n_devices + stream = itertools.repeat( + (np.zeros(history_shape), np.zeros(action_shape, dtype=np.int32), + np.zeros(obs_shape), np.zeros(reward_shape)) + ) + return trax_inputs.Inputs( + train_stream=lambda: stream, + train_eval_stream=lambda: stream, + eval_stream=lambda: stream, + input_shape=(history_shape[1:], action_shape[1:]), + input_dtype=(np.float32, np.int32), + target_shape=(obs_shape[1:], reward_shape[1:]), + target_dtype=(np.float32, np.float32), + ) + + def loss(mask_id=None, has_weights=False): + """Cross-entropy loss as scalar compatible with Trax masking.""" + return layers.Serial( + # Swap from (pred-obs, pred-reward, target-obs, target-reward) + # to (pred-obs, target-obs, pred-reward, target-reward). + layers.Parallel([], layers.Swap()), + # Cross-entropy loss for obs, L2 loss on reward. + layers.Parallel(layers.CrossEntropyLossScalar(mask_id, has_weights), + layers.L2LossScalar(mask_id, has_weights)), + # Add both losses. + layers.Add(), + # Zero out in this test. + layers.MulConstant(constant=0.0) + ) + + with self.tmp_dir() as output_dir: + # Run fake training just to save the parameters. + trainer = trainer_lib.Trainer( + model=model, + loss_fn=loss, + inputs=inputs, + optimizer=trax_opt.SM3, + lr_schedule=lr.MultifactorSchedule, + output_dir=output_dir, + ) + trainer.train_epoch(epoch_steps=1, eval_steps=1) + + # Repeat the history over and over again. + stream = itertools.repeat(np.zeros(history_shape)) + env_fn = functools.partial( + simulated_env_problem.RawSimulatedEnvProblem, + model=model, + history_length=history_shape[1], + trajectory_length=3, + batch_size=history_shape[0], + observation_space=gym.spaces.Box( + low=-np.inf, high=np.inf, shape=(obs_shape[1],)), + action_space=gym.spaces.Discrete(n=n_actions), + reward_range=(-1, 1), + discrete_rewards=False, + history_stream=stream, + output_dir=output_dir, + ) + + trainer = self._make_trainer( + train_env=env_fn(), + eval_env=env_fn(), + output_dir=output_dir, + ) + trainer.training_loop(n_epochs=2) + + def test_restarts(self): + with self.tmp_dir() as output_dir: + train_env = self.get_wrapped_env("CartPole-v0", 2) + eval_env = self.get_wrapped_env("CartPole-v0", 2) + + # Train for 1 epoch and save. + trainer = self._make_trainer( + train_env=train_env, + eval_env=eval_env, + output_dir=output_dir, + ) + self.assertEqual(trainer.epoch, 0) + trainer.training_loop(n_epochs=1) + self.assertEqual(trainer.epoch, 1) + + # Restore from the saved state. + trainer = self._make_trainer( + train_env=train_env, + eval_env=eval_env, + output_dir=output_dir, + ) + self.assertEqual(trainer.epoch, 1) + # Check that we can continue training from the restored checkpoint. + trainer.training_loop(n_epochs=2) + self.assertEqual(trainer.epoch, 2) + + def test_training_loop_multi_control(self): + gym.register( + "FakeEnv-v0", + entry_point="trax.rl.envs.fake_env:FakeEnv", + kwargs={"n_actions": 3, "n_controls": 2}, + ) + with self.tmp_dir() as output_dir: + trainer = self._make_trainer( + train_env=self.get_wrapped_env("FakeEnv-v0", 2), + eval_env=self.get_wrapped_env("FakeEnv-v0", 2), + output_dir=output_dir, + ) + trainer.training_loop(n_epochs=2) + + def test_training_loop_cartpole_serialized(self): + gin.bind_parameter("BoxSpaceSerializer.precision", 1) + with self.tmp_dir() as output_dir: + trainer = self._make_trainer( + train_env=self.get_wrapped_env("CartPole-v0", 2), + eval_env=self.get_wrapped_env("CartPole-v0", 2), + output_dir=output_dir, + model=functools.partial( + models.TransformerDecoder, + d_model=1, + d_ff=1, + n_layers=1, + n_heads=1, + max_len=1024, + mode="train", + ), + policy_and_value_vocab_size=4, + ) + trainer.training_loop(n_epochs=2) + + def test_training_loop_cartpole_minibatch(self): + with self.tmp_dir() as output_dir: + trainer = self._make_trainer( + train_env=self.get_wrapped_env("CartPole-v0", 2, batch_size=4), + eval_env=self.get_wrapped_env("CartPole-v0", 2), + output_dir=output_dir, + optimizer_batch_size=2, + ) + trainer.training_loop(n_epochs=2) + + +if __name__ == "__main__": + test.main() diff --git a/trax/rl/serialization_utils.py b/trax/rl/serialization_utils.py new file mode 100644 index 000000000..51de81002 --- /dev/null +++ b/trax/rl/serialization_utils.py @@ -0,0 +1,184 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for serializing trajectories into discrete sequences.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + + +def serialize_observations_and_actions( + observations, + actions, + mask, + observation_serializer, + action_serializer, + representation_length, +): + """Serializes observations and actions into a discrete sequence. + + Args: + observations: Array (B, T + 1, ...), of observations, where B is the batch + size and T is the number of timesteps excluding the last observation. + actions: Array (B, T, ...) of actions. + mask: Binary array (B, T) indicating where each sequence ends (1s while + it continues). + observation_serializer: SpaceSerializer for observations. + action_serializer: SpaceSerializer for actions. + representation_length: Number of symbols in the serialized sequence. The + sequence is padded up to this number. + Returns: + Pair (representation, mask), where representation is the serialized sequence + of shape (B, R) where R = representation_length, and mask is a binary array + of shape (B, R) indicating where each sequence ends. + """ + (batch_size, n_timesteps) = actions.shape[:2] + assert observations.shape[:2] == (batch_size, n_timesteps + 1) + assert mask.shape == (batch_size, n_timesteps) + + reprs = [] + for t in range(n_timesteps): + reprs.append(observation_serializer.serialize(observations[:, t, ...])) + reprs.append(action_serializer.serialize(actions[:, t, ...])) + reprs.append(observation_serializer.serialize(observations[:, -1, ...])) + reprs = np.concatenate(reprs, axis=1) + assert reprs.shape[1] <= representation_length + reprs = np.pad( + reprs, + pad_width=((0, 0), (0, representation_length - reprs.shape[1])), + mode="constant", + ) + + obs_repr_length = observation_serializer.representation_length + act_repr_length = action_serializer.representation_length + step_repr_length = obs_repr_length + act_repr_length + seq_lengths = np.sum(mask, axis=1).astype(np.int32) + repr_lengths = seq_lengths * step_repr_length + obs_repr_length + repr_mask = np.zeros((batch_size, representation_length), dtype=np.int32) + for (i, repr_length) in enumerate(repr_lengths): + repr_mask[i, :repr_length] = 1 + + return (reprs, repr_mask) + + +def observation_mask( + observation_serializer, action_serializer, representation_length +): + """Calculates an observation mask for a serialized sequence. + + Args: + observation_serializer: SpaceSerializer for observations. + action_serializer: SpaceSerializer for actions. + representation_length: Number of symbols in the serialized sequence. The + mask is padded up to this number. + + Returns: + Binary mask indicating which symbols in the representation correspond to + observations. + """ + mask = np.zeros(representation_length, dtype=np.int32) + obs_repr_length = observation_serializer.representation_length + step_repr_length = obs_repr_length + action_serializer.representation_length + for step_start_index in range(0, representation_length, step_repr_length): + mask[step_start_index:(step_start_index + obs_repr_length)] = 1 + return mask + + +def action_mask( + observation_serializer, action_serializer, representation_length +): + """Calculates an action mask for a serialized sequence. + + Args: + observation_serializer: SpaceSerializer for observations. + action_serializer: SpaceSerializer for actions. + representation_length: Number of symbols in the serialized sequence. The + mask is padded up to this number. + + Returns: + Binary mask indicating which symbols in the representation correspond to + actions. + """ + return 1 - observation_mask( + observation_serializer, action_serializer, representation_length + ) + + +def significance_map( + observation_serializer, action_serializer, representation_length +): + """Calculates a significance map for the entire serialized sequence. + + See SpaceSerializer.significance_map. + + Args: + observation_serializer: SpaceSerializer for observations. + action_serializer: SpaceSerializer for actions. + representation_length: Number of symbols in the serialized sequence. The + significance map is padded up to this number. + + Returns: + Significance map for the entire serialized sequence. + """ + sig_map = np.zeros(representation_length, dtype=np.int32) + obs_repr_length = observation_serializer.representation_length + act_repr_length = action_serializer.representation_length + step_repr_length = obs_repr_length + act_repr_length + for step_start_index in range(0, representation_length, step_repr_length): + act_start_index = step_start_index + obs_repr_length + step_end_index = step_start_index + step_repr_length + limit = representation_length - step_start_index + sig_map[step_start_index:act_start_index] = ( + observation_serializer.significance_map[:limit] + ) + limit = representation_length - act_start_index + sig_map[act_start_index:step_end_index] = ( + action_serializer.significance_map[:limit] + ) + return sig_map + + +def rewards_to_actions_map( + observation_serializer, + action_serializer, + n_timesteps, + representation_length, +): + """Calculates a mapping between the rewards and the serialized sequence. + + Used to broadcast advantages over the log-probabilities of corresponding + actions. + + Args: + observation_serializer: SpaceSerializer for observations. + action_serializer: SpaceSerializer for actions. + n_timesteps: Number of timesteps (length of the reward sequence). + representation_length: Number of symbols in the serialized sequence. + + Returns: + Array (T, R) translating from the reward sequence to actions in the + representation. + """ + r2a_map = np.zeros((n_timesteps, representation_length)) + obs_repr_length = observation_serializer.representation_length + act_repr_length = action_serializer.representation_length + step_repr_length = obs_repr_length + act_repr_length + for t in range(n_timesteps): + act_start_index = t * step_repr_length + obs_repr_length + r2a_map[t, act_start_index:(act_start_index + act_repr_length)] = 1 + return r2a_map diff --git a/trax/rl/serialization_utils_test.py b/trax/rl/serialization_utils_test.py new file mode 100644 index 000000000..c35c66c18 --- /dev/null +++ b/trax/rl/serialization_utils_test.py @@ -0,0 +1,168 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.rl.serialization_utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gin +import gym +import numpy as np +from tensorflow import test +from trax.rl import serialization_utils +from trax.rl import space_serializer + + +class SerializationTest(test.TestCase): + + def setUp(self): + super(SerializationTest, self).setUp() + self._serializer = space_serializer.create( + gym.spaces.Discrete(2), vocab_size=2 + ) + self._repr_length = 100 + self._serialization_utils_kwargs = { + "observation_serializer": self._serializer, + "action_serializer": self._serializer, + "representation_length": self._repr_length, + } + + def test_serializes_observations_and_actions(self): + (reprs, mask) = serialization_utils.serialize_observations_and_actions( + observations=np.array([[0, 1]]), + actions=np.array([[0]]), + mask=np.array([[1]]), + **self._serialization_utils_kwargs + ) + self.assertEqual(reprs.shape, (1, self._repr_length)) + self.assertEqual(mask.shape, (1, self._repr_length)) + self.assertGreater(np.sum(mask), 0) + self.assertEqual(np.max(mask), 1) + + def test_masks_length(self): + (reprs, mask) = serialization_utils.serialize_observations_and_actions( + observations=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 1]]), + actions=np.array([[0, 0], [0, 1], [0, 0]]), + mask=np.array([[1, 0], [1, 1], [1, 1]]), + **self._serialization_utils_kwargs + ) + # Trajectories 1 and 2 are longer than 0. + self.assertGreater(np.sum(mask[1]), np.sum(mask[0])) + self.assertGreater(np.sum(mask[2]), np.sum(mask[0])) + # Trajectory 0 is a common prefix of 1 and 2. 1 and 2 are different. + np.testing.assert_array_equal(reprs[0] * mask[0], reprs[1] * mask[0]) + np.testing.assert_array_equal(reprs[0] * mask[0], reprs[2] * mask[0]) + self.assertFalse(np.array_equal(reprs[1] * mask[1], reprs[2] * mask[2])) + # Trajectories should be padded with 0s. + np.testing.assert_array_equal( + reprs * (1 - mask), np.zeros((3, self._repr_length)) + ) + + def test_observation_and_action_masks_are_valid_and_complementary(self): + obs_mask = serialization_utils.observation_mask( + **self._serialization_utils_kwargs + ) + self.assertEqual(obs_mask.shape, (self._repr_length,)) + self.assertEqual(np.min(obs_mask), 0) + self.assertEqual(np.max(obs_mask), 1) + + act_mask = serialization_utils.action_mask( + **self._serialization_utils_kwargs + ) + self.assertEqual(act_mask.shape, (self._repr_length,)) + self.assertEqual(np.min(act_mask), 0) + self.assertEqual(np.max(act_mask), 1) + + np.testing.assert_array_equal( + obs_mask + act_mask, np.ones(self._repr_length) + ) + + def test_masks_observations(self): + (reprs, _) = serialization_utils.serialize_observations_and_actions( + # Observations are different, actions are the same. + observations=np.array([[0, 1], [1, 1]]), + actions=np.array([[0], [0]]), + mask=np.array([[1], [1]]), + **self._serialization_utils_kwargs + ) + obs_mask = serialization_utils.observation_mask( + **self._serialization_utils_kwargs + ) + act_mask = serialization_utils.action_mask( + **self._serialization_utils_kwargs + ) + + self.assertFalse(np.array_equal(reprs[0] * obs_mask, reprs[1] * obs_mask)) + np.testing.assert_array_equal(reprs[0] * act_mask, reprs[1] * act_mask) + + def test_masks_actions(self): + (reprs, _) = serialization_utils.serialize_observations_and_actions( + # Observations are the same, actions are different. + observations=np.array([[0, 1], [0, 1]]), + actions=np.array([[0], [1]]), + mask=np.array([[1], [1]]), + **self._serialization_utils_kwargs + ) + obs_mask = serialization_utils.observation_mask( + **self._serialization_utils_kwargs + ) + act_mask = serialization_utils.action_mask( + **self._serialization_utils_kwargs + ) + + np.testing.assert_array_equal(reprs[0] * obs_mask, reprs[1] * obs_mask) + self.assertFalse(np.array_equal(reprs[0] * act_mask, reprs[1] * act_mask)) + + def test_significance_map(self): + gin.bind_parameter("BoxSpaceSerializer.precision", 3) + significance_map = serialization_utils.significance_map( + observation_serializer=space_serializer.create( + gym.spaces.Box(low=0, high=1, shape=(2,)), vocab_size=2 + ), + action_serializer=space_serializer.create( + gym.spaces.MultiDiscrete(nvec=[2, 2]), vocab_size=2 + ), + representation_length=20, + ) + np.testing.assert_array_equal( + significance_map, + # obs1, act1, obs2, act2, obs3 cut after 4th symbol. + [0, 1, 2, 0, 1, 2, 0, 0, 0, 1, 2, 0, 1, 2, 0, 0, 0, 1, 2, 0], + ) + + def test_rewards_to_actions_map(self): + rewards = np.array([1, 2, 3]) + r2a_map = serialization_utils.rewards_to_actions_map( + observation_serializer=space_serializer.create( + gym.spaces.MultiDiscrete(nvec=[2, 2, 2]), vocab_size=2 + ), + action_serializer=space_serializer.create( + gym.spaces.MultiDiscrete(nvec=[2, 2]), vocab_size=2 + ), + n_timesteps=len(rewards), + representation_length=16, + ) + broadcast_rewards = np.dot(rewards, r2a_map) + np.testing.assert_array_equal( + broadcast_rewards, + # obs1, act1, obs2, act2, obs3 cut after 1st symbol. + [0, 0, 0, 1, 1, 0, 0, 0, 2, 2, 0, 0, 0, 3, 3, 0], + ) + + +if __name__ == "__main__": + test.main() diff --git a/trax/rl/simple.py b/trax/rl/simple.py new file mode 100644 index 000000000..6bff22c5b --- /dev/null +++ b/trax/rl/simple.py @@ -0,0 +1,236 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SimPLe helper functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +import os +import random + +from absl import logging +import numpy as np +from tensor2tensor.envs import env_problem_utils +from tensor2tensor.envs import trajectory +from tensorflow.io import gfile +from trax import utils + + +def load_trajectories(trajectory_dir, eval_frac): + """Loads trajectories from a possibly nested directory of pickles.""" + pkl_module = utils.get_pickle_module() + train_trajectories = [] + eval_trajectories = [] + # Search the entire directory subtree for trajectories. + for (subdir, _, filenames) in gfile.walk(trajectory_dir): + for filename in filenames: + shard_path = os.path.join(subdir, filename) + try: + with gfile.GFile(shard_path, "rb") as f: + trajectories = pkl_module.load(f) + pivot = int(len(trajectories) * (1 - eval_frac)) + train_trajectories.extend(trajectories[:pivot]) + eval_trajectories.extend(trajectories[pivot:]) + except EOFError: + logging.warning( + "Could not load trajectories from a corrupted shard %s.", + shard_path, + ) + assert train_trajectories, "Can't find training data in %s" % trajectory_dir + assert eval_trajectories, "Can't find evaluation data in %s" % trajectory_dir + return train_trajectories, eval_trajectories + + +def generate_examples(trajectories, trajectory_to_training_examples_fn): + """Generates an infinite stream of shuffled examples out of trajectories.""" + examples = [ + example # pylint: disable=g-complex-comprehension + for trajectory_examples in map( + trajectory_to_training_examples_fn, trajectories) + for example in trajectory_examples + ] + assert examples + while True: + random.shuffle(examples) + for example in examples: + yield example + + +def mix_streams(stream1, stream2, mix_prob): + """Mixes two streams together with a fixed probability.""" + while True: + # In the corner cases (mix_prob = 0 or 1) mixing the other stream never + # happens, because random() samples from the semi-open interval [0, 1). + if random.random() < mix_prob: + yield next(stream1) + else: + yield next(stream2) + + +def batch_stream(stream, batch_size): + """Batches a stream of training examples.""" + def make_batch(examples): + """Stacks a structure of numpy arrays nested in lists/tuples.""" + assert examples + if isinstance(examples[0], (list, tuple)): + return type(examples[0])( + make_batch([example[i] for example in examples]) + for i in range(len(examples[0])) + ) + else: + return np.stack(examples, axis=0) + + # Take consecutive batches from an infinite stream. This way there are no + # incomplete batches. We might get duplicate examples in the same batch, but + # that should be very rare. + while True: + yield make_batch(list(itertools.islice(stream, batch_size))) + + +# TODO(pkozakowski): This is mostly a simplified version of +# env_problem_utils.play_env_problem_with_policy, generalized to work with +# policies not being neural networks. Another difference is that it always +# collects exactly one trajectory from each environment in the batch. Unify if +# possible. +def play_env_problem(env, policy_fn): + """Plays an EnvProblem using a given policy function.""" + trajectories = [trajectory.Trajectory() for _ in range(env.batch_size)] + observations = env.reset() + for (traj, observation) in zip(trajectories, observations): + traj.add_time_step(observation=observation) + + done_so_far = np.array([False] * env.batch_size) + while not np.all(done_so_far): + padded_observations, _ = env.trajectories.observations_np( + len_history_for_policy=None) + actions = policy_fn(padded_observations) + (observations, rewards, dones, _) = env.step(actions) + for (traj, observation, action, reward, done) in zip( + trajectories, observations, actions, rewards, dones + ): + if not traj.done: + traj.change_last_time_step(action=action) + traj.add_time_step( + observation=observation, raw_reward=reward, done=done) + env.reset(indices=env_problem_utils.done_indices(dones)) + done_so_far = np.logical_or(done_so_far, dones) + return trajectories + + +def calculate_observation_error(real_trajectories, sim_trajectories): + """Calculates MSE of observations in two trajectories.""" + def pad_or_truncate(observations, desired_length): + (current_length, _) = observations.shape + if current_length < desired_length: + return np.pad( + observations, + pad_width=((0, desired_length - current_length), (0, 0)), + mode="edge", + ) + else: + return observations[:desired_length, :] + + def calculate_for_single_pair(real_trajectory, sim_trajectory): + real_obs = real_trajectory.observations_np + sim_obs = pad_or_truncate( + sim_trajectory.observations_np, real_trajectory.num_time_steps) + return np.sum((real_obs - sim_obs) ** 2, axis=0) + + return np.mean([ + calculate_for_single_pair(real_traj, sim_traj) + for (real_traj, sim_traj) in zip(real_trajectories, sim_trajectories) + ], axis=0) + + +def plot_observation_error(real_trajectories, sim_trajectories, mpl_plt): + """Plots observations from two trajectories on the same graph.""" + assert len(real_trajectories) == len(sim_trajectories) + assert real_trajectories + obs_dim = real_trajectories[0].last_time_step.observation.shape[0] + (w, h) = mpl_plt.rcParams["figure.figsize"] + ncols = len(real_trajectories) + nrows = obs_dim + (_, axes) = mpl_plt.subplots( + nrows, ncols, figsize=(w * ncols, h * nrows)) + for (traj_index, (real_traj, sim_traj)) in enumerate( + zip(real_trajectories, sim_trajectories) + ): + for dim_index in range(obs_dim): + for (traj, label) in ((real_traj, "real"), (sim_traj, "simulated")): + obs = traj.observations_np + ax = axes[dim_index, traj_index] + ax.set_title("trajectory {}, observation dimension {}".format( + traj_index, dim_index)) + ax.plot(np.arange(obs.shape[0]), obs[:, dim_index], label=label) + ax.legend() + + +class ReplayPolicy(object): + """Policy function repeating actions from a given batch of trajectories.""" + + def __init__(self, trajectories, out_of_bounds_action): + """Creates ReplayPolicy. + + Args: + trajectories: Batch of trajectories to repeat actions from. + out_of_bounds_action: Action to play after the replayed trajectory ends. + """ + self._trajectories = trajectories + self._out_of_bounds_action = out_of_bounds_action + self._step = 0 + + def __call__(self, observations): + del observations + + def get_action(traj): + action = None + if self._step < traj.num_time_steps: + action = traj.time_steps[self._step].action + # PS: action can still be None, if this is the last time-step in traj. + return action if action is not None else self._out_of_bounds_action + actions = np.array(list(map(get_action, self._trajectories))) + self._step += 1 + return actions + + +def evaluate_model(sim_env, real_trajectories, mpl_plt, n_to_plot=3): + """Reports the observation error metric and the corresponding plot.""" + if len(sim_env.observation_space.shape) != 1: + logging.warning( + "Could not evaluate the model - only environments with vector " + "observation spaces are supported." + ) + return + + assert len(real_trajectories) == sim_env.batch_size + + policy_fn = ReplayPolicy( + real_trajectories, + # Does not matter which action we play after the real trajetory ends, we + # cut the simulated one to match the real one anyway. + out_of_bounds_action=sim_env.action_space.sample(), + ) + + sim_trajectories = play_env_problem(sim_env, policy_fn) + obs_errors = calculate_observation_error(real_trajectories, sim_trajectories) + plot_observation_error( + real_trajectories[:n_to_plot], sim_trajectories[:n_to_plot], mpl_plt) + return { + "observation_error/{}".format(i): obs_error + for (i, obs_error) in enumerate(obs_errors) + } diff --git a/trax/rl/simple_test.py b/trax/rl/simple_test.py new file mode 100644 index 000000000..287fdc106 --- /dev/null +++ b/trax/rl/simple_test.py @@ -0,0 +1,304 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.rl.simple.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +import os + +import gin +import gym +from matplotlib import pyplot as plt +import mock +import numpy as np + +from tensor2tensor.envs import trajectory +from tensorflow import test +from tensorflow.io import gfile +from trax import backend +from trax import trainer_lib +from trax import utils +from trax.rl import simple +from trax.rl import simulated_env_problem +from trax.rl import space_serializer # pylint: disable=unused-import + + +class SimpleTest(test.TestCase): + + def _make_singleton_trajectory(self, observation): + t = trajectory.Trajectory() + t.add_time_step(observation=observation) + return t + + def _dump_trajectory_pickle(self, observations, path): + pkl_module = utils.get_pickle_module() + trajectories = list(map(self._make_singleton_trajectory, observations)) + with gfile.GFile(path, "wb") as f: + pkl_module.dump(trajectories, f) + + def test_loads_trajectories(self): + temp_dir = self.get_temp_dir() + # Dump two trajectory pickles with given observations. + self._dump_trajectory_pickle( + observations=[0, 1, 2, 3], path=os.path.join(temp_dir, "0.pkl")) + self._dump_trajectory_pickle( + observations=[4, 5, 6, 7], path=os.path.join(temp_dir, "1.pkl")) + (train_trajs, eval_trajs) = simple.load_trajectories( + temp_dir, eval_frac=0.25) + extract_obs = lambda t: t.last_time_step.observation + # The order of pickles is undefined, so we compare sets. + actual_train_obs = set(map(extract_obs, train_trajs)) + actual_eval_obs = set(map(extract_obs, eval_trajs)) + + # First 3 trajectories from each pickle go to train, the last one to eval. + expected_train_obs = {0, 1, 2, 4, 5, 6} + expected_eval_obs = {3, 7} + self.assertEqual(actual_train_obs, expected_train_obs) + self.assertEqual(actual_eval_obs, expected_eval_obs) + + def test_generates_examples(self): + observations = [0, 1, 2, 3] + trajectories = map(self._make_singleton_trajectory, observations) + trajectory_to_training_examples = lambda t: [t.last_time_step.observation] + stream = simple.generate_examples( + trajectories, trajectory_to_training_examples) + + # The examples are shuffled, so we compare sets. + self.assertEqual( + set(itertools.islice(stream, len(observations))), set(observations)) + # The stream is infinite, so we should be able to take a next element. + self.assertIn(next(stream), observations) + + def test_mixes_streams_with_prob_one(self): + # Mix infinite streams of 0s and 1s. + stream = simple.mix_streams( + itertools.repeat(0), itertools.repeat(1), mix_prob=1.0) + # Mixed stream should have only 0s. + self.assertEqual(set(itertools.islice(stream, 100)), {0}) + + def test_mixes_streams_with_prob_zero(self): + stream = simple.mix_streams( + itertools.repeat(0), itertools.repeat(1), mix_prob=0.0) + # Mixed stream should have only 1s. + self.assertEqual(set(itertools.islice(stream, 100)), {1}) + + def test_mixes_streams_with_prob_half(self): + stream = simple.mix_streams( + itertools.repeat(0), itertools.repeat(1), mix_prob=0.5) + # Mixed stream should have both 0s and 1s. + self.assertEqual(set(itertools.islice(stream, 100)), {0, 1}) + + def test_batches_stream(self): + stream = iter([(0, 1), (2, 3), (4, 5), (6, 7)]) + batched_stream = simple.batch_stream(stream, batch_size=2) + np.testing.assert_equal( + next(batched_stream), (np.array([0, 2]), np.array([1, 3]))) + np.testing.assert_equal( + next(batched_stream), (np.array([4, 6]), np.array([5, 7]))) + + def test_plays_env_problem(self): + # Shape: (time, trajectory). + observations = np.array([[0, 1], [2, 3], [4, 5]]) + rewards = np.array([[0, 1], [1, 0]]) + actions = np.array([[1, 2], [2, 0]]) + # We end the second environment 2 times, but we shouldn't collect the second + # trajectory. + dones = np.array([[False, True], [True, True]]) + infos = [{}, {}] + + mock_env = mock.MagicMock() + mock_env.batch_size = 2 + # (observations, lengths) + mock_env.trajectories.observations_np.return_value = (None, None) + mock_env.reset.return_value = observations[0] + mock_env.step.side_effect = zip(observations[1:], rewards, dones, infos) + + mock_policy_fn = mock.MagicMock() + mock_policy_fn.side_effect = actions + + trajectories = simple.play_env_problem(mock_env, mock_policy_fn) + self.assertEqual(len(trajectories), 2) + expected_lengths = [3, 2] + for (i, (traj, expected_length)) in enumerate( + zip(trajectories, expected_lengths)): + self.assertEqual(traj.num_time_steps, expected_length) + np.testing.assert_array_equal( + traj.observations_np, observations[:expected_length, i]) + np.testing.assert_array_equal( + traj.raw_rewards_np, rewards[:(expected_length - 1), i]) + np.testing.assert_array_equal( + traj.actions_np, actions[:(expected_length - 1), i]) + + def _make_trajectory(self, observations=None, actions=None): + t = trajectory.Trajectory() + if observations is None: + observations = itertools.repeat(None) + if actions is None: + actions = itertools.repeat(None) + for (observation, action) in zip(observations, actions): + t.add_time_step(observation=observation, action=action) + return t + + def test_replay_policy(self): + trajectories = [ + self._make_trajectory(actions=actions) + for actions in map(np.array, [[1, 2], [3]]) + ] + policy_fn = simple.ReplayPolicy(trajectories, out_of_bounds_action=0) + np.testing.assert_array_equal(policy_fn(None), [1, 3]) + np.testing.assert_array_equal(policy_fn(None), [2, 0]) + + def test_observation_error_zero_for_same_trajectories(self): + observations = np.array([[0], [2], [1]]) + (traj1, traj2) = map(self._make_trajectory, (observations, observations)) + error = simple.calculate_observation_error([traj1], [traj2]) + np.testing.assert_array_almost_equal(error, [0]) + + def test_observation_error_positive_for_different_trajectories(self): + observations1 = np.array([[1], [2], [3]]) + observations2 = np.array([[0], [2], [3]]) + (traj1, traj2) = map(self._make_trajectory, (observations1, observations2)) + error = simple.calculate_observation_error([traj1], [traj2]) + np.testing.assert_array_less([0], error) + + def test_observation_error_dims_correspond_to_observation_dims(self): + observations1 = np.array([[0, 1, 0], [0, 2, 0], [0, 3, 0]]) + observations2 = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]]) + (traj1, traj2) = map(self._make_trajectory, (observations1, observations2)) + error = simple.calculate_observation_error([traj1], [traj2]) + self.assertEqual(error.shape, (3,)) + np.testing.assert_array_almost_equal(error[0], 0) + self.assertFalse(np.allclose(error[1], 0)) + np.testing.assert_array_almost_equal(error[2], 0) + + def test_observation_error_increases_with_distance(self): + observations_zero = np.array([[0], [0], [0]]) + observations_positive = np.array([[3], [2], [1]]) + (traj_zero, traj_positive, traj_negative) = map( + self._make_trajectory, + (observations_zero, observations_positive, -observations_positive), + ) + error_small = simple.calculate_observation_error( + [traj_zero], [traj_positive]) + error_big = simple.calculate_observation_error( + [traj_positive], [traj_negative]) + np.testing.assert_array_less(error_small, error_big) + + def test_observation_error_increases_with_real_trajectory_length(self): + observations_real_short = np.array([[1], [2]]) + observations_real_long = np.array([[1], [2], [3]]) + observations_sim = np.array([[0], [1]]) + (traj_real_short, traj_real_long, traj_sim) = map( + self._make_trajectory, + (observations_real_short, observations_real_long, observations_sim), + ) + error_small = simple.calculate_observation_error( + real_trajectories=[traj_real_short], sim_trajectories=[traj_sim]) + error_big = simple.calculate_observation_error( + real_trajectories=[traj_real_long], sim_trajectories=[traj_sim]) + np.testing.assert_array_less(error_small, error_big) + + def test_observation_error_same_when_sim_trajectory_longer(self): + observations_real = np.array([[0], [1]]) + observations_sim_short = np.array([[1], [2]]) + observations_sim_long = np.array([[1], [2], [3]]) + (traj_real, traj_sim_short, traj_sim_long) = map( + self._make_trajectory, + (observations_real, observations_sim_short, observations_sim_long), + ) + error1 = simple.calculate_observation_error( + real_trajectories=[traj_real], sim_trajectories=[traj_sim_short]) + error2 = simple.calculate_observation_error( + real_trajectories=[traj_real], sim_trajectories=[traj_sim_long]) + np.testing.assert_array_almost_equal(error1, error2) + + def test_observation_error_reduces_over_trajectories(self): + observations1 = np.array([[1], [2], [3]]) + observations2 = np.array([[0], [2], [3]]) + (traj1, traj2) = map(self._make_trajectory, (observations1, observations2)) + error = simple.calculate_observation_error([traj1, traj1], [traj2, traj2]) + self.assertEqual(error.shape, (1,)) + + @staticmethod + @mock.patch.object(trainer_lib, "restore_state", autospec=True) + def _make_env( + mock_restore_state, observation_space, action_space, + max_trajectory_length, batch_size, + ): + # (model_params, opt_state) + mock_restore_state.return_value.params = (None, None) + + gin.bind_parameter("BoxSpaceSerializer.precision", 1) + + predict_output = (np.array([[[0.0]]] * batch_size)) + mock_model_fn = mock.MagicMock() + mock_model_fn.return_value.side_effect = itertools.repeat(predict_output) + mock_model_fn.return_value.initialize_once.return_value = ((), ()) + + return simulated_env_problem.SerializedSequenceSimulatedEnvProblem( + model=mock_model_fn, + reward_fn=(lambda _1, _2: np.zeros(batch_size)), + done_fn=(lambda _1, _2: np.full((batch_size,), False)), + vocab_size=1, + max_trajectory_length=max_trajectory_length, + batch_size=batch_size, + observation_space=observation_space, + action_space=action_space, + reward_range=(-1, 1), + discrete_rewards=False, + history_stream=itertools.repeat(None), + output_dir=None, + ) + + def test_evaluates_model_with_vector_observation_space(self): + with backend.use_backend("numpy"): + env = self._make_env( # pylint: disable=no-value-for-parameter + observation_space=gym.spaces.Box(shape=(2,), low=0, high=1), + action_space=gym.spaces.Discrete(n=1), + max_trajectory_length=2, + batch_size=3, + ) + trajectories = [ + self._make_trajectory(observations, actions) # pylint: disable=g-complex-comprehension + for (observations, actions) in [ + (np.array([[0, 1]]), np.array([0])), + (np.array([[1, 2], [3, 4]]), np.array([0, 0])), + (np.array([[1, 2], [3, 4], [5, 6]]), np.array([0, 0, 0])), + ] + ] + metrics = simple.evaluate_model(env, trajectories, plt) + self.assertIsNotNone(metrics) + self.assertEqual(len(metrics), 2) + + def test_fails_to_evaluate_model_with_matrix_observation_space(self): + with backend.use_backend("numpy"): + env = self._make_env( # pylint: disable=no-value-for-parameter + observation_space=gym.spaces.Box(shape=(2, 2), low=0, high=1), + action_space=gym.spaces.Discrete(n=1), + max_trajectory_length=2, + batch_size=1, + ) + trajectories = [ + self._make_trajectory(np.array([[0, 1], [2, 3]]), np.array([0]))] + metrics = simple.evaluate_model(env, trajectories, plt) + self.assertIsNone(metrics) + + +if __name__ == "__main__": + test.main() diff --git a/trax/rl/simple_trainer.py b/trax/rl/simple_trainer.py new file mode 100644 index 000000000..624abeed7 --- /dev/null +++ b/trax/rl/simple_trainer.py @@ -0,0 +1,341 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SimPLe trainer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import itertools +import os +import random +import time + +from absl import logging +import gin +from matplotlib import pyplot as plt +from tensorflow.io import gfile +from trax import inputs as trax_inputs +from trax import jaxboard +from trax import trainer_lib +from trax.rl import base_trainer +from trax.rl import simple +from trax.rl import simulated_env_problem + + +class SimPLe(base_trainer.BaseTrainer): + """SimPLe trainer.""" + + def __init__(self, + train_env, + eval_env, + output_dir, + policy_trainer_class, + n_real_epochs=10, + data_eval_frac=0.125, + model_train_batch_size=64, + n_model_initial_train_steps=1000, + n_model_train_steps_per_epoch=1000, + simulated_env_problem_class=( + simulated_env_problem.SerializedSequenceSimulatedEnvProblem), + simulated_batch_size=16, + n_simulated_epochs=1000, + trajectory_dump_dir=None, + initial_trajectory_dir=None, + initial_trajectory_mix_prob=0.5, + initial_model=None, + init_policy_from_world_model=False, + **kwargs): + super(SimPLe, self).__init__(train_env, eval_env, output_dir, **kwargs) + self._policy_dir = os.path.join(output_dir, "policy") + self._model_dir = os.path.join(output_dir, "model") + # Initialize the policy trainer lazily, so in case of initializing the + # policy from world model checkpoint, the trainer will try to load the + # checkpoint _after_ it's been created in train_model(). + self._policy_trainer_fn = functools.partial( + policy_trainer_class, + train_env=train_env, + eval_env=eval_env, + output_dir=self._policy_dir, + async_mode=self._async_mode, + init_policy_from_world_model_output_dir=( + self._model_dir if init_policy_from_world_model else None + ), + ) + self._policy_trainer = None + self._n_real_epochs = n_real_epochs + self._model_train_batch_size = model_train_batch_size + self._n_model_initial_train_steps = n_model_initial_train_steps + self._n_model_train_steps_per_epoch = n_model_train_steps_per_epoch + self._data_eval_frac = data_eval_frac + + gfile.makedirs(self._model_dir) + if initial_model is not None: + gfile.copy( + initial_model, + os.path.join(self._model_dir, "model.pkl"), + overwrite=True, + ) + self._initial_model = initial_model + self._initial_trajectories = None + + self._sim_env = simulated_env_problem_class( + batch_size=None, + observation_space=train_env.observation_space, + action_space=train_env.action_space, + reward_range=train_env.reward_range, + discrete_rewards=train_env.discrete_rewards, + history_stream=None, # TODO(pkozakowski): Support this. + output_dir=self._model_dir, + ) + self._simulated_batch_size = simulated_batch_size + self._n_simulated_epochs = n_simulated_epochs + + # If trajectory_dump_dir is not provided explicitly, save the trajectories + # in output_dir. + if trajectory_dump_dir is None: + trajectory_dump_dir = os.path.join(output_dir, "trajectories") + self._trajectory_dump_root_dir = trajectory_dump_dir + + self._initial_trajectory_dir = initial_trajectory_dir + self._initial_trajectory_mix_prob = initial_trajectory_mix_prob + + self._summary_writer = jaxboard.SummaryWriter(self._output_dir) + + self._simple_epoch = 0 + self._policy_epoch = 0 + self._model_train_step = 0 + + @property + def policy_trainer(self): + if self._policy_trainer is None: + self._policy_trainer = self._policy_trainer_fn() + return self._policy_trainer + + @property + def epoch(self): + return self._simple_epoch + + def train_epoch(self, evaluate=True): + if self._simple_epoch > 0 or not self._has_initial_data: + logging.info( + "Collect trajectories by running the policy in the real environment.") + self.collect_trajectories(evaluate=evaluate) + if self._simple_epoch > 0 or not self._initial_model: + logging.info( + "Train the model of the environment on the collected trajectories.") + skipped = self.train_model() + if evaluate and not skipped: + logging.info("Evaluate the trained model.") + self.evaluate_model() + logging.info("Train the policy inside the simulated environment generated " + "by the model.") + self.train_policy() + + self._simple_epoch += 1 + + def evaluate(self): + self.policy_trainer.evaluate() + + def save(self): + # Nothing to do, as we save stuff continuously. + pass + + def flush_summaries(self): + self._summary_writer.flush() + + def collect_trajectories(self, evaluate): + logging.info("SimPLe epoch [% 6d]: collecting data.", self._simple_epoch) + start_time = time.time() + + self.policy_trainer.train_env = self.train_env + self.policy_trainer.trajectory_dump_dir = os.path.join( + self._trajectory_dump_root_dir, str(self.epoch)) + self._policy_epoch += self._n_real_epochs + self.policy_trainer.training_loop(self._policy_epoch, evaluate=evaluate) + + logging.vlog( + 1, "Collecting trajectories took %0.2f sec.", time.time() - start_time) + + def train_model(self): + """Train the model. + + Returns: + whether the training was skipped due to a restart. + """ + logging.info("SimPLe epoch [% 6d]: training model.", self._simple_epoch) + start_time = time.time() + + (train_stream, eval_stream) = self._make_input_streams() + # Ignore n_devices for now. + inputs = lambda _: trax_inputs.Inputs( # pylint: disable=g-long-lambda + train_stream=(lambda: train_stream), + train_eval_stream=(lambda: train_stream), + eval_stream=(lambda: eval_stream), + input_shape=self._sim_env.model_input_shape, + input_dtype=self._sim_env.model_input_dtype, + # TODO(lukaszkaiser): correct those, they may differ from inputs. + target_shape=self._sim_env.model_input_shape, + target_dtype=self._sim_env.model_input_dtype) + + if self._simple_epoch == 0: + train_steps = self._n_model_initial_train_steps + else: + train_steps = self._n_model_train_steps_per_epoch + self._model_train_step += train_steps + with gin.config_scope("world_model"): + state = trainer_lib.train( + model=self._sim_env.model, + inputs=inputs, + train_steps=self._model_train_step, + output_dir=self._model_dir, + has_weights=True, + ) + + logging.vlog( + 1, "Training model took %0.2f sec.", time.time() - start_time) + return state.step > self._model_train_step + + def train_policy(self): + logging.info("SimPLe epoch [% 6d]: training policy.", self._simple_epoch) + start_time = time.time() + + self._sim_env.initialize( + batch_size=self._simulated_batch_size, + history_stream=itertools.repeat(None), + ) + # We never want async mode in the simulated env. + original_async_mode = self.policy_trainer.async_mode + self.policy_trainer.async_mode = False + self.policy_trainer.train_env = self._sim_env + # Don't dump trajectories from the simulated environment. + self.policy_trainer.trajectory_dump_dir = None + self._policy_epoch += self._n_simulated_epochs + self.policy_trainer.training_loop(self._policy_epoch, evaluate=False) + # Revert back to the original async mode in the policy trainer. + self.policy_trainer.async_mode = original_async_mode + + logging.vlog( + 1, "Training policy took %0.2f sec.", time.time() - start_time) + + @property + def _has_own_data(self): + return self._simple_epoch > 0 or self._initial_trajectory_dir is None + + @property + def _has_initial_data(self): + return self._initial_trajectory_dir is not None + + def _load_trajectories(self, initial): + # Cache the initial trajectories in memory, as loading them can take a lot + # of time and they don't change. + if initial: + if self._initial_trajectories is not None: + return self._initial_trajectories + trajectory_dir = self._initial_trajectory_dir + else: + trajectory_dir = self._trajectory_dump_root_dir + + trajectories = simple.load_trajectories( + trajectory_dir, self._data_eval_frac + ) + + if initial: + self._initial_trajectories = trajectories + return trajectories + + def _make_input_streams(self): + def make_example_streams(initial): + (train_trajs, eval_trajs) = self._load_trajectories(initial) + generate_examples = functools.partial( + simple.generate_examples, + trajectory_to_training_examples_fn=( + self._sim_env.trajectory_to_training_examples), + ) + return tuple(map(generate_examples, (train_trajs, eval_trajs))) + + # We mix two data sources: trajectories collected in this SimPLe training + # loop ("own" data) and trajectories collected before, outside of this + # training loop ("initial" data). + mix_prob = self._initial_trajectory_mix_prob + + if self._has_initial_data: + start_time = time.time() + # Load the initial, precollected data. + (init_train_stream, init_eval_stream) = make_example_streams(initial=True) + logging.vlog( + 1, "Loading initial trajectories took %0.2f sec.", + time.time() - start_time + ) + else: + (init_train_stream, init_eval_stream) = (None, None) + mix_prob = 0.0 # Take just our own collected data. + + if self._has_own_data: + start_time = time.time() + # Load trajectories collected in all epochs so far. + (own_train_stream, own_eval_stream) = make_example_streams(initial=False) + logging.vlog( + 1, "Loading own trajectories took %0.2f sec.", + time.time() - start_time + ) + else: + # We start the loop with training the model, so we don't have our own + # collected data yet. + (own_train_stream, own_eval_stream) = (None, None) + mix_prob = 1.0 # Take just the initial data. + + def mix_and_batch(streams): + (init_stream, own_stream) = streams + mixed_stream = simple.mix_streams(init_stream, own_stream, mix_prob) + return simple.batch_stream(mixed_stream, self._model_train_batch_size) + + return tuple( + map(mix_and_batch, ( + (init_train_stream, own_train_stream), + (init_eval_stream, own_eval_stream), + ))) + + def evaluate_model(self): + logging.info("SimPLe epoch [% 6d]: evaluating model.", self._simple_epoch) + start_time = time.time() + + self._sim_env.initialize( + batch_size=self._simulated_batch_size, + history_stream=itertools.repeat(None), + ) + + (_, eval_trajectories) = self._load_trajectories( + # If we have any trajectories collected in this run, evaluate on them. + # Otherwise, use the initial dataset. + initial=(not self._has_own_data) + ) + chosen_trajectories = [ + random.choice(eval_trajectories) + for _ in range(self._sim_env.batch_size) + ] + summaries = simple.evaluate_model(self._sim_env, chosen_trajectories, plt) + if summaries is not None: + for (name, value) in summaries.items(): + self._summary_writer.scalar( + "simple/{}".format(name), value, step=self._simple_epoch) + self._summary_writer.plot( + "simple/model_eval_plot", plt, step=self._simple_epoch) + self.flush_summaries() + + logging.vlog( + 1, "Evaluating model took %0.2f sec.", time.time() - start_time) diff --git a/trax/rl/simple_trainer_test.py b/trax/rl/simple_trainer_test.py new file mode 100644 index 000000000..5cf94e90b --- /dev/null +++ b/trax/rl/simple_trainer_test.py @@ -0,0 +1,96 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.rl.simple_trainer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +import gin + +from tensor2tensor.envs import gym_env_problem +from tensor2tensor.rl import gym_utils +from tensorflow import test +from trax import models +from trax.rl import envs # pylint: disable=unused-import +from trax.rl import simulated_env_problem +from trax.rl import trainers + + +class SimpleTrainerTest(test.TestCase): + + def _make_wrapped_env(self, name, max_episode_steps=2): + wrapper_fn = functools.partial( + gym_utils.gym_env_wrapper, + **{ + "rl_env_max_episode_steps": max_episode_steps, + "maxskip_env": False, + "rendered_env": False, + "rendered_env_resize_to": None, # Do not resize frames + "sticky_actions": False, + "output_dtype": None, + }) + + return gym_env_problem.GymEnvProblem(base_env_name=name, + batch_size=2, + env_wrapper_fn=wrapper_fn, + discrete_rewards=False) + + def test_training_loop_acrobot(self): + gin.bind_parameter("BoxSpaceSerializer.precision", 2) + gin.bind_parameter("trainer_lib.train.eval_steps", 1) + trainer = trainers.SimPLe( + train_env=self._make_wrapped_env("Acrobot-v1"), + eval_env=self._make_wrapped_env("Acrobot-v1"), + output_dir=self.get_temp_dir(), + policy_trainer_class=functools.partial( + trainers.PPO, + policy_and_value_model=functools.partial( + models.FrameStackMLP, + n_frames=1, + hidden_sizes=(), + output_size=1, + ), + n_optimizer_steps=1, + ), + n_real_epochs=1, + data_eval_frac=0.5, + model_train_batch_size=2, + n_model_initial_train_steps=1, + n_model_train_steps_per_epoch=1, + simulated_env_problem_class=functools.partial( + simulated_env_problem.SerializedSequenceSimulatedEnvProblem, + model=functools.partial( + models.TransformerLM, + d_model=2, + n_layers=0, + max_len=64, + ), + reward_fn=simulated_env_problem.acrobot_reward_fn, + done_fn=simulated_env_problem.acrobot_done_fn, + vocab_size=4, + max_trajectory_length=4, + ), + simulated_batch_size=2, + n_simulated_epochs=1, + ) + trainer.training_loop(n_epochs=1) + + +if __name__ == "__main__": + test.main() diff --git a/trax/rl/simulated_env_problem.py b/trax/rl/simulated_env_problem.py new file mode 100644 index 000000000..cfac934c8 --- /dev/null +++ b/trax/rl/simulated_env_problem.py @@ -0,0 +1,499 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""EnvProblem for environments simulated by a Trax model.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import random + +import numpy as np + +from tensor2tensor.envs import env_problem +from trax import backend +from trax import trainer_lib +from trax import utils +from trax.backend import random as jax_random +from trax.rl import serialization_utils +from trax.rl import space_serializer + + +class SimulatedEnvProblem(env_problem.EnvProblem): + """EnvProblem base class for environments simulated by Trax models. + + The initial observations to start the model are taken from + initial_observation_stream. This iterator in incremented in every reset(). + + A checkpoint saved by the Trax trainer should be available in output_dir. + """ + + def __init__(self, model, batch_size, observation_space, action_space, + reward_range, discrete_rewards, history_stream, output_dir, + model_predict_kwargs=None): + """Initializes the env. + + Args: + model: Trax model. + batch_size: (int) Number of simulated environments run in parallel. + observation_space: (gym.Space) Observation space. + action_space: (gym.Space) Action space. + reward_range: (tuple) Pair (min_reward, max_reward). + discrete_rewards: (bool) Whether to discretize the rewards. + history_stream: Iterator yielding batches of initial input data for the + model. The format is implementation-specific. + output_dir: (str) Output dir. + model_predict_kwargs: (dict) Additional model keyword arguments for + inference. Useful when different config is needed for training and + inference, e.g. train with memory efficient attention and predict with + the regular one. + """ + self._model = model + if model_predict_kwargs is None: + model_predict_kwargs = {} + model_predict = self._model(mode="predict", **model_predict_kwargs) + def predict_with_state(*args, **kwargs): + output = model_predict(*args, **kwargs) + return (output, model_predict.state) + self._model_predict = backend.jit(predict_with_state) + self._model_initialize = model_predict.initialize_once + + self._observation_space = observation_space + self._action_space = action_space + self._reward_range = reward_range + self._output_dir = output_dir + + self._predict_fn = None + self._rng = None + self._model_state = None + self._history_stream = None + + # Call the super's ctor. It will use some of the member fields, so we call + # it in the end. + super(SimulatedEnvProblem, self).__init__( + batch_size=batch_size, + discrete_rewards=discrete_rewards, + history_stream=history_stream, + ) + + self.seed() + + def initialize_environments(self, + history_stream, + batch_size=1, + parallelism=1): + """Initializes the environments. + + Args: + history_stream: Iterator yielding batches of initial input data for the + model. The format is implementation-specific. + batch_size: (int) Number of environments in a batch. + parallelism: (int) Unused. + """ + del parallelism + + trax_state = trainer_lib.restore_state(self._output_dir) + # TODO(lukaszkaiser): both model state and parameters by default include + # the loss layer. Currently, we access the pure-model parameters by just + # indexing, [0] here. But we should make it more explicit in a better API. + model_params = trax_state.opt_state.params[0] + self._model_state = trax_state.model_state[0] + + def predict_fn(inputs, rng): + (output, self._model_state) = self._model_predict( + inputs, params=model_params, state=self._model_state, rng=rng + ) + return output + + self._predict_fn = predict_fn + self._history_stream = history_stream + self._steps = np.zeros(batch_size, dtype=np.int32) + + @property + def observation_space(self): + return self._observation_space + + @property + def action_space(self): + return self._action_space + + @property + def reward_range(self): + return self._reward_range + + def seed(self, seed=None): + if seed is None: + seed = random.randint(0, 2**31 - 1) + self._rng = jax_random.get_prng(seed) + return super(SimulatedEnvProblem, self).seed(seed=seed) + + def _reset_model(self, predict_fn, indices, history, rng): + """Resets the environments at the given indices. + + Should be implemented in subclasses. + + Args: + predict_fn: Function running prediction with the model. + indices: List of indices of underlying envs to call reset on. + history: Initial input data for the model. + rng: Jax RNG. + + Returns: + np.ndarray of batched observations from the reset envs. + """ + raise NotImplementedError + + def _step_model(self, predict_fn, actions, rng): + """Takes a step in all environments. + + Should be implemented in subclasses. + + Args: + predict_fn: Function running prediction with the model. + actions: (np.ndarray) with first dimension equal to the batch size. + rng: Jax RNG. + + Returns: + a tuple of batched raw observations, rewards and dones. + """ + raise NotImplementedError + + def trajectory_to_training_examples(self, trajectory): + raise NotImplementedError + + @property + def model_input_shape(self): + raise NotImplementedError + + @property + def model_input_dtype(self): + raise NotImplementedError + + def _reset(self, indices): + """Resets environments at the given indices. + + Args: + indices: list of indices of underlying envs to call reset on. + + Returns: + np.ndarray of batched observations from the reset envs. + """ + history = next(self._history_stream) + (subrng, self._rng) = jax_random.split(self._rng) + return self._reset_model(self._predict_fn, indices, history, subrng) + + def _step(self, actions): + """Takes a step in all environments. + + Args: + actions: (np.ndarray) with first dimension equal to the batch size. + + Returns: + a tuple of batched raw observations, raw rewards, dones and infos. + """ + # Predict the next observation. + (subrng, self._rng) = jax_random.split(self._rng) + (observation, reward, done) = self._step_model( + self._predict_fn, actions, subrng) + return (observation, reward, done, {}) + + @property + def model(self): + return self._model + + +class RawSimulatedEnvProblem(SimulatedEnvProblem): + """SimulatedEnvProblem running a model operating on raw tensors. + + Wraps an autoregressive trax model of signature + (observation_history, action) -> (observation, reward) in an EnvProblem. + The model is assumed to take a fixed number of last observations as input + and produce a single observation, which is fed back into the model in the + next environment step. + + Shape requirements (without the batch dimension): + observation: Consistent with observation_space. + observation_history: (history_length,) + observation.shape. + action: Consistent with action_space. + reward: (1,). The singleton dimension is removed in step(). + """ + + def __init__(self, history_length, trajectory_length, *args, **kwargs): + """Initializes the env. + + Args: + history_length: (int) Number of last observations fed into the model. + trajectory_length: (int) Length of each trajectory unrolled from the + model. + *args: (tuple) Positional arguments passed to the base class. + **kwargs: (dict) Keyword arguments passed to the base class. + """ + self._history_length = history_length + self._trajectory_length = trajectory_length + self._history = None + self._steps = None + + super(RawSimulatedEnvProblem, self).__init__(*args, **kwargs) + + def initialize_environments(self, batch_size=1, **kwargs): + """Initializes the environments.""" + self._history = None + self._steps = np.zeros(batch_size) + return super(RawSimulatedEnvProblem, self).initialize_environments( + batch_size=batch_size, **kwargs) + + def _reset_model(self, predict_fn, indices, history, rng): + del predict_fn + del rng + assert history.shape == ((self._batch_size, self._history_length) + + self.observation_space.shape) + + if self._history is None: + # At the first reset, all indices should be triggered. + assert set(indices) == set(range(self._batch_size)) + self._history = np.array(history) + else: + history = history[indices, ...] + self._history[indices, ...] = history + + # Reset the step counters. + self._steps[indices] = 0 + + # Return just the last timestep at the given indices. + return history[:, -1, ...] + + def _step_model(self, predict_fn, actions, rng): + (observation, reward) = predict_fn((self._history, actions), rng=rng) + + # Roll the history one timestep back and append the new observation. + self._history = np.roll(self._history, shift=-1, axis=1) + self._history[:, -1, ...] = observation + + # Increment the step counters and determine which envs are done. + self._steps += 1 + done = self._steps == self._trajectory_length + + # Call copy() to get the data as numpy arrays. + observation = observation.copy() + # Reshape the rewards to get rid of the extra dimension. + reward = np.squeeze(reward.copy(), axis=1) + return (observation, reward, done) + + +class SerializedSequenceSimulatedEnvProblem(SimulatedEnvProblem): + """SimulatedEnvProblem running a model operating on sequences of symbols. + + Wraps an autoregressive trax model of signature past_symbols -> symbol_probs + in an EnvProblem. The model is assumed to take a sequence of symbols as input + and produce distributions over all symbols in the sequence. The next symbol + is sampled and fed back to the model in the next decoding step. + + Shape requirements (without the batch dimension): + past_symbols: (max_trajectory_length * L,) + symbol_probs: (max_trajectory_length * L, vocab_size) + where L is the representation length of one environment step. + + Observations, actions, rewards and done flags are (de)serialized from/to + sequences of symbols using an EnvSerializer passed to the constructor. + """ + + def __init__(self, model, reward_fn, done_fn, vocab_size, + max_trajectory_length, observation_space, action_space, + significance_decay=1.0, **kwargs): + """Initializes the env. + + Args: + model: trax model to use for simulation. It's assumed to take keyword + arguments vocab_size and mode, where vocab_size is the number of symbols + in the vocabulary and mode is either "train" or "eval". + + reward_fn: Function (previous_observation, current_observation) -> reward. + done_fn: Function (previous_observation, current_observation) -> done. + vocab_size: (int) Number of symbols in the vocabulary. + max_trajectory_length: (int) Maximum length of a trajectory unrolled from + the model. + observation_space: (gym.Space) Observation space. + action_space: (gym.Space) Action space. + significance_decay: (float) Decay for training weights of progressively + less significant symbols in the representation. + **kwargs: (dict) Keyword arguments passed to the base class. + """ + self._reward_fn = reward_fn + self._done_fn = done_fn + self._vocab_size = vocab_size + self._max_trajectory_length = max_trajectory_length + self._significance_decay = significance_decay + self._steps = None + self._observation_space = None + self._action_space = None + self._last_observations = None + + self._obs_serializer = space_serializer.create( + observation_space, self._vocab_size) + self._action_serializer = space_serializer.create( + action_space, self._vocab_size) + self._obs_repr_length = self._obs_serializer.representation_length + self._act_repr_length = self._action_serializer.representation_length + self._step_repr_length = self._obs_repr_length + self._act_repr_length + + # We assume that the model takes vocab_size as an argument (e.g. + # TransformerLM). + model = functools.partial(model, vocab_size=vocab_size) + super(SerializedSequenceSimulatedEnvProblem, self).__init__( + model=model, + observation_space=observation_space, + action_space=action_space, + **kwargs + ) + + def initialize_environments(self, batch_size=1, **kwargs): + """Initializes the environments.""" + self._steps = np.zeros(batch_size, dtype=np.int32) + self._last_observations = np.full( + (batch_size,) + self._observation_space.shape, np.nan) + self._last_symbols = np.zeros((batch_size, 1), dtype=np.int32) + super(SerializedSequenceSimulatedEnvProblem, self).initialize_environments( + batch_size=batch_size, **kwargs) + (subrng, self._rng) = jax_random.split(self._rng) + (_, self._init_model_state) = self._model_initialize( + input_shapes=(batch_size, 1), input_dtype=np.int32, rng=subrng + ) + + def _predict_obs(self, predict_fn, rng): + obs_repr = np.zeros( + (self._steps.shape[0], self._obs_repr_length), dtype=np.int32, + ) + for (i, subrng) in enumerate(jax_random.split(rng, self._obs_repr_length)): + log_probs = predict_fn(self._last_symbols, rng=subrng) + self._last_symbols = utils.gumbel_sample(log_probs) + obs_repr[:, i] = self._last_symbols[:, 0] + return self._obs_serializer.deserialize(obs_repr) + + def _consume_act(self, actions, predict_fn, rng): + act_repr = self._action_serializer.serialize(actions) + for (i, subrng) in enumerate(jax_random.split(rng, self._act_repr_length)): + # Run the network to update the inference buffers, but ignore the result. + predict_fn(self._last_symbols, rng=subrng) + self._last_symbols = act_repr[:, i:(i + 1)] + + def _reset_model(self, predict_fn, indices, history, rng): + # TODO(pkozakowski): Random starts. + del history + + indices = np.array(indices) + assert indices.shape[0] in (0, self._steps.shape[0]), ( + # TODO(pkozakowski): Lift this requirement. + "Only resetting all envs at once is supported." + ) + + self._model_state = self._init_model_state + self._last_symbols[indices] = 0 + self._steps[indices] = 0 + observation = self._predict_obs(predict_fn, rng)[indices] + self._last_observations[indices] = observation + return observation + + def _step_model(self, predict_fn, actions, rng): + self._consume_act(actions, predict_fn, rng) + self._steps += 1 + observation = self._predict_obs(predict_fn, rng) + reward = self._reward_fn(self._last_observations, observation) + done = self._done_fn(self._last_observations, observation) + # Copy the last observations, so that we don't overwrite data stored in a + # trajectory when resetting the environment (see _reset_model). + self._last_observations = np.copy(observation) + done = np.logical_or(done, self._steps == self._max_trajectory_length - 1) + return (observation, reward, done) + + def trajectory_to_training_examples(self, trajectory): + (repr_length,) = self.model_input_shape + seq_mask = np.ones((1, trajectory.num_time_steps - 1)) + (reprs, repr_mask) = serialization_utils.serialize_observations_and_actions( + # Serialization works on batches, so we add a singleton batch dimension. + trajectory.observations_np[None, ...], + trajectory.actions_np[None, ...], + seq_mask, + self._obs_serializer, + self._action_serializer, + repr_length, + ) + reprs = reprs[0, ...].astype(self.model_input_dtype) + sig_weights = ( + self._significance_decay ** serialization_utils.significance_map( + self._obs_serializer, self._action_serializer, repr_length + )[None, ...] + ) + obs_mask = serialization_utils.observation_mask( + self._obs_serializer, self._action_serializer, repr_length + ) + weights = (sig_weights * obs_mask * repr_mask)[0, ...] + # (inputs, targets, weights) + return [(reprs, reprs, weights)] + + @property + def model_input_shape(self): + return (self._max_trajectory_length * self._step_repr_length,) + + @property + def model_input_dtype(self): + return np.int32 + + +def cartpole_done_fn(previous_observation, current_observation): + del previous_observation + x_threshold = 2.4 + theta_threshold = 12 * 2 * np.pi / 360 + x = current_observation[:, 0] + theta = current_observation[:, 2] + return np.logical_or(np.abs(x) > x_threshold, np.abs(theta) > theta_threshold) + + +def cartpole_reward_fn(previous_observation, current_observation): + done = cartpole_done_fn(previous_observation, current_observation) + return 1.0 - done # Unit reward for every timestep until the end. + + +def acrobot_done_fn(previous_observation, current_observation): + del previous_observation + theta1 = current_observation[:, 0] + theta2 = current_observation[:, 1] + return -np.cos(theta1) - np.cos(theta2 + theta1) > 1.0 + + +def acrobot_reward_fn(previous_observation, current_observation): + done = acrobot_done_fn(previous_observation, current_observation) + return -1.0 + done # -1 reward for every timestep until the end. + + +def onlinetune_done_fn(previous_observation, current_observation): + del previous_observation + del current_observation + # Never return "done" from the environment, rely on max_trajectory_length + # instead. + return False + + +def onlinetune_reward_fn( + previous_observation, + current_observation, + # 2 is the evaluation accuracy metric in the default settings of + # OnlineTuneEnv. + dim_index=2, +): + prev = previous_observation[:, dim_index] + cur = current_observation[:, dim_index] + return cur - prev diff --git a/trax/rl/simulated_env_problem_test.py b/trax/rl/simulated_env_problem_test.py new file mode 100644 index 000000000..4f00acee0 --- /dev/null +++ b/trax/rl/simulated_env_problem_test.py @@ -0,0 +1,292 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.rl.simulated_env_problem.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +import gin +import gym +import mock +import numpy as np + +from tensor2tensor.envs import trajectory +from tensorflow import test +from trax import backend +from trax import trainer_lib +from trax.rl import simulated_env_problem + + +class RawSimulatedEnvProblemTest(test.TestCase): + + @staticmethod + @mock.patch.object(trainer_lib, "restore_state", autospec=True) + def _create_env(mock_restore_state, model, histories, + trajectory_length): + # (model_params, opt_state) + mock_restore_state.return_value.params = (None, None) + space = gym.spaces.Discrete(100) + return simulated_env_problem.RawSimulatedEnvProblem( + model=model, + history_length=histories.shape[2], + trajectory_length=trajectory_length, + batch_size=1, + observation_space=space, + action_space=space, + reward_range=(-1, 1), + discrete_rewards=True, + history_stream=iter(histories), + output_dir=None, + ) + + def test_communicates_with_model(self): + # Mock model increasing the observation by action, reward is the parity of + # the new observation. + def mock_transition(inputs, *args, **kwargs): + del args + del kwargs + (observations, actions) = inputs + new_observations = observations[:, -1] + actions + rewards = np.array([[int(new_observations % 2 == 0)]]) + return (new_observations, rewards) + + mock_model_fn = mock.MagicMock() + mock_model_fn.return_value.side_effect = mock_transition + mock_model = mock_model_fn.return_value + + actions_to_take = np.array([[1], [3]]) + histories = np.array([[[0, 1, 2, 3]]]) + expected_observations = np.array([[3], [4], [7]]) + expected_rewards = np.array([[1], [0]]) + expected_dones = np.array([[False], [True]]) + expected_histories = np.array([[[0, 1, 2, 3]], [[1, 2, 3, 4]]]) + expected_actions = actions_to_take + + with backend.use_backend("numpy"): + env = self._create_env( # pylint: disable=no-value-for-parameter + model=mock_model_fn, + histories=histories, + trajectory_length=len(actions_to_take), + ) + actual_observations = [env.reset()] + actual_rewards = [] + actual_dones = [] + actual_histories = [] + actual_actions = [] + for action in actions_to_take: + (observation, reward, done, _) = env.step(action) + actual_observations.append(observation) + actual_rewards.append(reward) + actual_dones.append(done) + # Mock call is a tuple (args, kwargs). There is one positional argument, + # which is a tuple (history, action). + (((history, action),), _) = mock_model.call_args + actual_actions.append(action) + actual_histories.append(history) + + np.testing.assert_array_equal(actual_observations, expected_observations) + np.testing.assert_array_equal(actual_rewards, expected_rewards) + np.testing.assert_array_equal(actual_dones, expected_dones) + np.testing.assert_array_equal(actual_histories, expected_histories) + np.testing.assert_array_equal(actual_actions, expected_actions) + + def test_takes_new_history(self): + histories = np.array([[[0, 1, 2]], [[3, 4, 5]]]) + + with backend.use_backend("numpy"): + env = self._create_env( # pylint: disable=no-value-for-parameter + model=mock.MagicMock(), + histories=histories, + trajectory_length=2, + ) + env.reset() + observation = env.reset() + np.testing.assert_array_equal(observation, [5]) + + +class SerializedSequenceSimulatedEnvProblemTest(test.TestCase): + + def _make_env( + self, observation_space, action_space, vocab_size, + predict_fn=None, reward_fn=None, done_fn=None, + batch_size=None, max_trajectory_length=None, + ): + mock_model_fn = mock.MagicMock() + if predict_fn is not None: + mock_model_fn.return_value = predict_fn + mock_model_fn.return_value.initialize_once.return_value = ((), ()) + return simulated_env_problem.SerializedSequenceSimulatedEnvProblem( + model=mock_model_fn, + reward_fn=reward_fn, + done_fn=done_fn, + vocab_size=vocab_size, + max_trajectory_length=max_trajectory_length, + batch_size=batch_size, + observation_space=observation_space, + action_space=action_space, + reward_range=(-1, 1), + discrete_rewards=False, + history_stream=itertools.repeat(None), + output_dir=None, + ) + + def _make_trajectory(self, observations, actions): + assert len(observations) == len(actions) + 1 + t = trajectory.Trajectory() + for (obs, act) in zip(observations, actions): + t.add_time_step(observation=obs, action=act, done=False) + t.add_time_step(observation=observations[-1], done=True) + return t + + @mock.patch.object(trainer_lib, "restore_state", autospec=True) + def test_communicates_with_model(self, mock_restore_state): + gin.bind_parameter("BoxSpaceSerializer.precision", 1) + vocab_size = 16 + # Mock model predicting a fixed sequence of symbols. It is made such that + # the first two observations are different and the last one is equal to the + # first. + symbols = [ + 1, 1, 2, 2, 0, 0, # obs1 act1 + 1, 2, 2, 1, 0, 0, # obs2 act2 + 1, 1, 2, 2, # obs3 + ] + def make_prediction(symbol): + one_hot = np.eye(vocab_size)[symbol] + log_probs = (1 - one_hot) * -100.0 # Virtually deterministic. + # (4 obs symbols + 1 action symbol) * 3 timesteps = 15. + return np.array([[log_probs]]) + + mock_predict_fn = mock.MagicMock() + mock_predict_fn.side_effect = map(make_prediction, symbols) + + with backend.use_backend("numpy"): + # (model_params, opt_state) + mock_restore_state.return_value.params = (None, None) + env = self._make_env( + predict_fn=mock_predict_fn, + reward_fn=(lambda _1, _2: np.array([0.5])), + done_fn=(lambda _1, _2: np.array([False])), + vocab_size=vocab_size, + batch_size=1, + max_trajectory_length=3, + observation_space=gym.spaces.Box(low=0, high=5, shape=(4,)), + action_space=gym.spaces.MultiDiscrete(nvec=[2, 2]), + ) + + def assert_input_suffix(expected_symbols): + actual_symbols = np.array([ + symbol.item() for ((symbol,), _) in mock_predict_fn.call_args_list[ + -len(expected_symbols): + ] + ]) + np.testing.assert_array_equal(actual_symbols, expected_symbols) + + actions = [[0, 1], [1, 0]] + + obs1 = env.reset() + assert_input_suffix(symbols[:3]) + + (obs2, reward, done, _) = env.step(np.array([actions[0]])) + # Symbols going into the decoder when predicting the next observation are: + # the last symbol of the previous observation, all action symbols, all + # symbols but the last one of the next observation. + assert_input_suffix([symbols[3]] + actions[0] + symbols[6:9]) + self.assertFalse(np.array_equal(obs1, obs2)) + np.testing.assert_array_equal(reward, [0.5]) + np.testing.assert_array_equal(done, [False]) + + (obs3, reward, done, _) = env.step(np.array([actions[1]])) + assert_input_suffix([symbols[9]] + actions[1] + symbols[12:15]) + np.testing.assert_array_equal(obs1, obs3) + np.testing.assert_array_equal(reward, [0.5]) + np.testing.assert_array_equal(done, [True]) + + def test_makes_training_example(self): + env = self._make_env( + vocab_size=2, + observation_space=gym.spaces.Discrete(2), + action_space=gym.spaces.Discrete(2), + max_trajectory_length=3, + ) + t = self._make_trajectory(observations=[0, 1, 0], actions=[1, 0]) + examples = env.trajectory_to_training_examples(t) + + # There should be 1 example with the whole trajectory. + self.assertEqual(len(examples), 1) + [(inputs, targets, weights)] = examples + # inputs == targets for autoregressive sequence prediction. + np.testing.assert_array_equal(inputs, targets) + # Assert array shapes and datatypes. + self.assertEqual(inputs.shape, env.model_input_shape) + self.assertEqual(inputs.dtype, env.model_input_dtype) + self.assertEqual(weights.shape, env.model_input_shape) + # Actions should be masked out. + self.assertEqual(np.min(weights), 0) + # At least part of the observation should have full weight. + self.assertEqual(np.max(weights), 1) + + def test_makes_training_examples_from_trajectories_of_different_lengths(self): + env = self._make_env( + vocab_size=2, + observation_space=gym.spaces.Discrete(2), + action_space=gym.spaces.Discrete(2), + max_trajectory_length=3, + ) + t1 = self._make_trajectory(observations=[0, 1], actions=[1]) + [(x1, _, w1)] = env.trajectory_to_training_examples(t1) + t2 = self._make_trajectory(observations=[0, 1, 0], actions=[1, 0]) + [(x2, _, w2)] = env.trajectory_to_training_examples(t2) + + # Examples should be padded to the same shape. + self.assertEqual(x1.shape, x2.shape) + self.assertEqual(w1.shape, w2.shape) + # Cumulative weight should increase with trajectory length. + self.assertGreater(np.sum(w2), np.sum(w1)) + + def test_masked_representation_changes_with_observation(self): + env = self._make_env( + vocab_size=2, + observation_space=gym.spaces.Discrete(2), + action_space=gym.spaces.Discrete(2), + max_trajectory_length=3, + ) + t1 = self._make_trajectory(observations=[0, 1], actions=[1]) + [(x1, _, w1)] = env.trajectory_to_training_examples(t1) + t2 = self._make_trajectory(observations=[0, 0], actions=[1]) + [(x2, _, w2)] = env.trajectory_to_training_examples(t2) + + self.assertFalse(np.array_equal(x1 * w1, x2 * w2)) + + def test_masked_representation_doesnt_change_with_action(self): + env = self._make_env( + vocab_size=2, + observation_space=gym.spaces.Discrete(2), + action_space=gym.spaces.Discrete(2), + max_trajectory_length=3, + ) + t1 = self._make_trajectory(observations=[0, 1], actions=[1]) + [(x1, _, w1)] = env.trajectory_to_training_examples(t1) + t2 = self._make_trajectory(observations=[0, 1], actions=[0]) + [(x2, _, w2)] = env.trajectory_to_training_examples(t2) + + np.testing.assert_array_equal(x1 * w1, x2 * w2) + + +if __name__ == "__main__": + test.main() diff --git a/trax/rl/space_serializer.py b/trax/rl/space_serializer.py new file mode 100644 index 000000000..036735303 --- /dev/null +++ b/trax/rl/space_serializer.py @@ -0,0 +1,216 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Serialization of elements of Gym spaces into discrete sequences.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy + +from absl import logging +import gin +import gym +import numpy as np + + +class SpaceSerializer(object): + """Base class for Gym space serializers. + + Attrs: + space_type: (type) Gym space class that this SpaceSerializer corresponds + to. Should be defined in subclasses. + representation_length: (int) Number of symbols in the representation of + every element of the space. + significance_map: (np.ndarray) Integer array of the same size as the + discrete representation, where elements describe the significance of + symbols, e.g. in fixed-precision encoding. 0 is the most significant + symbol, 1 the second most significant etc. + """ + + space_type = None + representation_length = None + significance_map = None + + def __init__(self, space, vocab_size): + """Creates a SpaceSerializer. + + Subclasses should retain the signature. + + Args: + space: (gym.Space) Gym space of type self.space_type. + vocab_size: (int) Number of symbols in the vocabulary. + """ + assert isinstance(space, self.space_type) + self._space = space + self._vocab_size = vocab_size + + def serialize(self, data): + """Serializes a batch of space elements into discrete sequences. + + Should be defined in subclasses. + + Args: + data: A batch of batch_size elements of the Gym space to be serialized. + + Returns: + int32 array of shape (batch_size, self.representation_length). + """ + raise NotImplementedError + + def deserialize(self, representation): + """Deserializes a batch of discrete sequences into space elements. + + Should be defined in subclasses. + + Args: + representation: int32 Numpy array of shape + (batch_size, self.representation_length) to be deserialized. + + Returns: + A batch of batch_size deserialized elements of the Gym space. + """ + raise NotImplementedError + + +def create(space, vocab_size): + """Creates a SpaceSerializer for the given Gym space.""" + return { + gym.spaces.Box: BoxSpaceSerializer, + gym.spaces.Discrete: DiscreteSpaceSerializer, + gym.spaces.MultiDiscrete: MultiDiscreteSpaceSerializer, + }[type(space)](space, vocab_size) + + +@gin.configurable(blacklist=["space", "vocab_size"]) +class BoxSpaceSerializer(SpaceSerializer): + """Serializer for gym.spaces.Box. + + Assumes that the space is bounded. Internally rescales it to the [0, 1] + interval and uses a fixed-precision encoding. + """ + + space_type = gym.spaces.Box + + def __init__(self, space, vocab_size, precision=2, max_range=(-100.0, 100.0)): + self._precision = precision + + # Some gym envs (e.g. CartPole) have unreasonably high bounds for + # observations. We clip so we can represent them. + bounded_space = copy.copy(space) + (min_low, max_high) = max_range + bounded_space.low = np.maximum(space.low, min_low) + bounded_space.high = np.minimum(space.high, max_high) + if (not np.allclose(bounded_space.low, space.low) or + not np.allclose(bounded_space.high, space.high)): + logging.warning( + "Space limits %s, %s out of bounds %s. Clipping to %s, %s.", + str(space.low), str(space.high), str(max_range), + str(bounded_space.low), str(bounded_space.high) + ) + + super(BoxSpaceSerializer, self).__init__(bounded_space, vocab_size) + + def serialize(self, data): + array = data + batch_size = array.shape[0] + array = (array - self._space.low) / (self._space.high - self._space.low) + digits = [] + for digit_index in range(-1, -self._precision - 1, -1): + threshold = self._vocab_size ** digit_index + digit = np.array(array / threshold).astype(np.int32) + # For the corner case of x == high. + digit[digit == self._vocab_size] -= 1 + digits.append(digit) + array -= digit * threshold + digits = np.stack(digits, axis=-1) + return np.reshape(digits, (batch_size, -1)) + + def deserialize(self, representation): + digits = representation + batch_size = digits.shape[0] + digits = np.reshape(digits, (batch_size, -1, self._precision)) + array = np.zeros(digits.shape[:-1]) + for digit_index_in_seq in range(self._precision): + digit_index = -digit_index_in_seq - 1 + array += self._vocab_size ** digit_index * digits[..., digit_index_in_seq] + array = np.reshape(array, (batch_size,) + self._space.shape) + return array * (self._space.high - self._space.low) + self._space.low + + @property + def representation_length(self): + return self._precision * self._space.low.size + + @property + def significance_map(self): + return np.reshape(np.broadcast_to( + np.arange(self._precision), self._space.shape + (self._precision,)), -1) + + +class DiscreteSpaceSerializer(SpaceSerializer): + """Serializer for gym.spaces.Discrete. + + Assumes that the size of the space fits in the number of symbols. + """ + + space_type = gym.spaces.Discrete + representation_length = 1 + + def __init__(self, space, vocab_size): + super(DiscreteSpaceSerializer, self).__init__(space, vocab_size) + assert space.n <= vocab_size, ( + "Discrete space size should fit in the number of symbols.") + + def serialize(self, data): + return np.reshape(data, (-1, 1)).astype(np.int32) + + def deserialize(self, representation): + return np.reshape(representation, -1) + + @property + def significance_map(self): + return np.zeros(1, dtype=np.int32) + + +class MultiDiscreteSpaceSerializer(SpaceSerializer): + """Serializer for gym.spaces.MultiDiscrete. + + Assumes that the number of categories in each dimension fits in the number of + symbols. + """ + + space_type = gym.spaces.MultiDiscrete + + def __init__(self, space, vocab_size): + super(MultiDiscreteSpaceSerializer, self).__init__(space, vocab_size) + assert np.max(space.nvec) <= vocab_size, ( + "MultiDiscrete maximum number of categories should fit in the number " + "of symbols." + ) + + def serialize(self, data): + return data.astype(np.int32) + + def deserialize(self, representation): + return representation + + @property + def representation_length(self): + return len(self._space.nvec) + + @property + def significance_map(self): + return np.zeros(self.representation_length, dtype=np.int32) diff --git a/trax/rl/space_serializer_test.py b/trax/rl/space_serializer_test.py new file mode 100644 index 000000000..46751215a --- /dev/null +++ b/trax/rl/space_serializer_test.py @@ -0,0 +1,161 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for trax.rl.space_serializer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gin +import gym +import numpy as np +from tensorflow import test +from trax.rl import space_serializer + + +class BoxSpaceSerializerTest(test.TestCase): + + def _make_space_and_serializer( + self, low=-10, high=10, shape=(2,), + # Weird vocab_size to test that it doesn't only work with powers of 2. + vocab_size=257, + # Enough precision to represent float32s accurately. + precision=4, + ): + gin.bind_parameter("BoxSpaceSerializer.precision", precision) + space = gym.spaces.Box(low=low, high=high, shape=shape) + serializer = space_serializer.create(space, vocab_size=vocab_size) + return (space, serializer) + + def _sample_batch(self, space): + return np.reshape(space.sample(), (1,) + space.shape) + + def test_representation_length(self): + (space, serializer) = self._make_space_and_serializer() + input_array = self._sample_batch(space) + representation = serializer.serialize(input_array) + self.assertEqual( + representation.shape, (1, serializer.representation_length)) + + def test_commutes(self): + (space, serializer) = self._make_space_and_serializer() + input_array = self._sample_batch(space) + representation = serializer.serialize(input_array) + output_array = serializer.deserialize(representation) + np.testing.assert_array_almost_equal(input_array, output_array) + + def test_representation_changes(self): + (space, serializer) = self._make_space_and_serializer() + array1 = self._sample_batch(space) + array2 = -array1 + (repr1, repr2) = tuple(map(serializer.serialize, (array1, array2))) + self.assertFalse(np.array_equal(repr1, repr2)) + + def test_bounds_space(self): + gin.bind_parameter("BoxSpaceSerializer.max_range", (-10.0, 10.0)) + (_, serializer) = self._make_space_and_serializer( + # Too wide range to represent, need to clip. + low=-1e18, high=1e18, + shape=(1,)) + input_array = np.array([[1.2345]]) + representation = serializer.serialize(input_array) + output_array = serializer.deserialize(representation) + np.testing.assert_array_almost_equal(input_array, output_array) + + def test_significance_map(self): + (_, serializer) = self._make_space_and_serializer(shape=(2,)) + np.testing.assert_array_equal( + serializer.significance_map, [0, 1, 2, 3, 0, 1, 2, 3]) + + def test_serializes_boundaries(self): + vocab_size = 256 + precision = 4 + (_, serializer) = self._make_space_and_serializer( + low=-1, high=1, shape=(1,), vocab_size=vocab_size, precision=precision, + ) + input_array = np.array([[-1, 1]]) + representation = serializer.serialize(input_array) + np.testing.assert_array_equal( + representation, [[0] * precision + [vocab_size - 1] * precision] + ) + + +class DiscreteSpaceSerializerTest(test.TestCase): + + def setUp(self): + super(DiscreteSpaceSerializerTest, self).setUp() + self._space = gym.spaces.Discrete(n=2) + self._serializer = space_serializer.create(self._space, vocab_size=2) + + def _sample_batch(self): + return np.reshape(self._space.sample(), (1,) + self._space.shape) + + def test_representation_length(self): + input_array = self._sample_batch() + representation = self._serializer.serialize(input_array) + self.assertEqual( + representation.shape, (1, self._serializer.representation_length)) + + def test_commutes(self): + input_array = self._sample_batch() + representation = self._serializer.serialize(input_array) + output_array = self._serializer.deserialize(representation) + np.testing.assert_array_almost_equal(input_array, output_array) + + def test_representation_changes(self): + array1 = self._sample_batch() + array2 = 1 - array1 + (repr1, repr2) = tuple(map(self._serializer.serialize, (array1, array2))) + self.assertFalse(np.array_equal(repr1, repr2)) + + def test_significance_map(self): + np.testing.assert_array_equal(self._serializer.significance_map, [0]) + + +class MultiDiscreteSpaceSerializerTest(test.TestCase): + + def setUp(self): + super(MultiDiscreteSpaceSerializerTest, self).setUp() + self._space = gym.spaces.MultiDiscrete(nvec=[2, 2]) + self._serializer = space_serializer.create(self._space, vocab_size=2) + + def _sample_batch(self): + return np.reshape(self._space.sample(), (1,) + self._space.shape) + + def test_representation_length(self): + input_array = self._sample_batch() + representation = self._serializer.serialize(input_array) + self.assertEqual( + representation.shape, (1, self._serializer.representation_length)) + + def test_commutes(self): + input_array = self._sample_batch() + representation = self._serializer.serialize(input_array) + output_array = self._serializer.deserialize(representation) + np.testing.assert_array_almost_equal(input_array, output_array) + + def test_representation_changes(self): + array1 = self._sample_batch() + array2 = 1 - array1 + (repr1, repr2) = tuple(map(self._serializer.serialize, (array1, array2))) + self.assertFalse(np.array_equal(repr1, repr2)) + + def test_significance_map(self): + np.testing.assert_array_equal(self._serializer.significance_map, [0, 0]) + + +if __name__ == "__main__": + test.main() diff --git a/trax/rl/trainers.py b/trax/rl/trainers.py new file mode 100644 index 000000000..d455a9164 --- /dev/null +++ b/trax/rl/trainers.py @@ -0,0 +1,37 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Trainers defined in trax.rl.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gin + +from trax.rl import ppo_trainer +from trax.rl import simple_trainer + + +# Ginify +def trainer_configure(*args, **kwargs): + kwargs["module"] = "trax.rl.trainers" + kwargs["blacklist"] = ["train_env", "eval_env", "output_dir"] + return gin.external_configurable(*args, **kwargs) + + +# pylint: disable=invalid-name +PPO = trainer_configure(ppo_trainer.PPO) +SimPLe = trainer_configure(simple_trainer.SimPLe) diff --git a/trax/rl_trainer.py b/trax/rl_trainer.py new file mode 100644 index 000000000..43bf714d5 --- /dev/null +++ b/trax/rl_trainer.py @@ -0,0 +1,209 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Trainer for RL environments. + +For now we only support PPO as RL algorithm. + +Sample invocation: + +TRAIN_BATCH_SIZE=32 +python trax/rl_trainer.py \ + --config_file=trax/rl/configs/acrobot.gin \ + --train_batch_size=${TRAIN_BATCH_SIZE} \ + --output_dir=${HOME}/ppo_acrobot \ + --vmodule=*/tensor2tensor/*=1 \ + --alsologtostderr +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import multiprocessing +import os + +from absl import app +from absl import flags +from absl import logging +import gin +import jax +from jax.config import config +from tensor2tensor import envs # pylint: disable=unused-import +from tensor2tensor.envs import env_problem_utils +from tensor2tensor.rl.google import atari_utils # GOOGLE-INTERNAL: +from trax import rl # pylint: disable=unused-import +from trax.rl import envs as rl_envs # pylint: disable=unused-import +from trax.rl import trainers as rl_trainers + + +FLAGS = flags.FLAGS + +flags.DEFINE_boolean( + "jax_debug_nans", False, + "Setting to true will help to debug nans and disable jit.") +flags.DEFINE_boolean("disable_jit", False, "Setting to true will disable jit.") + +flags.DEFINE_string("output_dir", "", "Output dir.") +flags.DEFINE_string("envs_output_dir", "", "Output dir for the envs.") +flags.DEFINE_multi_string("config_file", None, + "Configuration file with parameters (.gin).") +flags.DEFINE_multi_string("config", None, + "Configuration parameters (gin string).") +flags.DEFINE_bool("use_tpu", False, "Whether we're running on TPU.") +flags.DEFINE_bool("xm", False, "Copy atari roms?") +flags.DEFINE_integer("train_batch_size", 32, + "Number of parallel environments during training.") +flags.DEFINE_integer("eval_batch_size", 4, "Batch size for evaluation.") +flags.DEFINE_boolean("parallelize_envs", False, + "If true, sets parallelism to number of cpu cores.") +flags.DEFINE_string("trajectory_dump_dir", "", + "Directory to dump trajectories to.") + +# TODO(afrozm): Find a better way to do these configurations. +flags.DEFINE_string("train_server_bns", "", "Train Server's BNS.") +flags.DEFINE_string("eval_server_bns", "", "Eval Server's BNS.") + +flags.DEFINE_bool("async_mode", False, "Async mode.") + + +# Not just "train" to avoid a conflict with trax.train in GIN files. +@gin.configurable(blacklist=[ + "output_dir", "train_batch_size", "eval_batch_size", "trajectory_dump_dir" +]) +def train_rl( + output_dir, + train_batch_size, + eval_batch_size, + env_name="ClientEnv-v0", + max_timestep=None, + clip_rewards=False, + rendered_env=False, + resize_dims=(105, 80), + trainer_class=rl_trainers.PPO, + n_epochs=10000, + trajectory_dump_dir=None, +): + """Train the RL agent. + + Args: + output_dir: Output directory. + train_batch_size: Number of parallel environments to use for training. + eval_batch_size: Number of parallel environments to use for evaluation. + env_name: Name of the environment. + max_timestep: Int or None, the maximum number of timesteps in a trajectory. + The environment is wrapped in a TimeLimit wrapper. + clip_rewards: Whether to clip and discretize the rewards. + rendered_env: Whether the environment has visual input. If so, a + RenderedEnvProblem will be used. + resize_dims: Pair (height, width), dimensions to resize the visual + observations to. + trainer_class: RLTrainer class to use. + n_epochs: Number epochs to run the training for. + trajectory_dump_dir: Directory to dump trajectories to. + """ + + if FLAGS.jax_debug_nans: + config.update("jax_debug_nans", True) + + if FLAGS.use_tpu: + config.update("jax_platform_name", "tpu") + else: + config.update("jax_platform_name", "gpu") + + + # TODO(pkozakowski): Find a better way to determine this. + train_env_kwargs = {} + eval_env_kwargs = {} + if "OnlineTuneEnv" in env_name: + envs_output_dir = FLAGS.envs_output_dir or os.path.join(output_dir, "envs") + train_env_output_dir = os.path.join(envs_output_dir, "train") + eval_env_output_dir = os.path.join(envs_output_dir, "eval") + train_env_kwargs = {"output_dir": train_env_output_dir} + eval_env_kwargs = {"output_dir": eval_env_output_dir} + + if "ClientEnv" in env_name: + train_env_kwargs["per_env_kwargs"] = [{ + "remote_env_address": os.path.join(FLAGS.train_server_bns, str(replica)) + } for replica in range(train_batch_size)] + + eval_env_kwargs["per_env_kwargs"] = [{ + "remote_env_address": os.path.join(FLAGS.eval_server_bns, str(replica)) + } for replica in range(eval_batch_size)] + + # TODO(afrozm): Should we leave out some cores? + parallelism = multiprocessing.cpu_count() if FLAGS.parallelize_envs else 1 + + train_env = env_problem_utils.make_env( + batch_size=train_batch_size, + env_problem_name=env_name, + resize=rendered_env, + resize_dims=resize_dims, + max_timestep=max_timestep, + clip_rewards=clip_rewards, + parallelism=parallelism, + use_tpu=FLAGS.use_tpu, + **train_env_kwargs) + assert train_env + + eval_env = env_problem_utils.make_env( + batch_size=eval_batch_size, + env_problem_name=env_name, + resize=rendered_env, + resize_dims=resize_dims, + max_timestep=max_timestep, + clip_rewards=clip_rewards, + parallelism=parallelism, + use_tpu=FLAGS.use_tpu, + **eval_env_kwargs) + assert eval_env + + def run_training_loop(): + """Runs the training loop.""" + logging.info("Starting the training loop.") + + trainer = trainer_class( + output_dir=output_dir, + train_env=train_env, + eval_env=eval_env, + trajectory_dump_dir=trajectory_dump_dir, + async_mode=FLAGS.async_mode, + ) + trainer.training_loop(n_epochs=n_epochs) + + if FLAGS.jax_debug_nans or FLAGS.disable_jit: + with jax.disable_jit(): + run_training_loop() + else: + run_training_loop() + + +def main(argv): + del argv + logging.info("Starting RL training.") + + gin_configs = FLAGS.config or [] + gin.parse_config_files_and_bindings(FLAGS.config_file, gin_configs) + + train_rl( + output_dir=FLAGS.output_dir, + train_batch_size=FLAGS.train_batch_size, + eval_batch_size=FLAGS.eval_batch_size, + trajectory_dump_dir=(FLAGS.trajectory_dump_dir or None), + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/trax/trainer.py b/trax/trainer.py new file mode 100644 index 000000000..495b72833 --- /dev/null +++ b/trax/trainer.py @@ -0,0 +1,135 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Trax trainer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import datetime +import os + +from absl import app +from absl import flags +from absl import logging + +import gin +import jax +import tensorflow as tf +from trax import backend +from trax import trainer_lib + +FLAGS = flags.FLAGS + +flags.DEFINE_string("dataset", None, "Which dataset to use.") +flags.DEFINE_string("model", None, "Which model to train.") +flags.DEFINE_string("data_dir", None, "Path to the directory with data.") +flags.DEFINE_string("output_dir", None, + "Path to the directory to save logs and checkpoints.") +flags.DEFINE_multi_string("config_file", None, + "Configuration file with parameters (.gin).") +flags.DEFINE_multi_string("config", None, + "Configuration parameters (gin string).") +flags.DEFINE_integer("log_level", logging.INFO, "Log level.") +flags.DEFINE_bool("use_tpu", False, "Whether we're running on TPU.") +flags.DEFINE_bool("enable_eager_execution", True, + "Whether we're running TF in eager mode.") +flags.DEFINE_bool("tf_xla", True, "Whether to turn on XLA for TF.") +flags.DEFINE_bool("tf_opt_pin_to_host", False, "Whether to turn on TF " + "pin-to-host optimization.") +flags.DEFINE_bool("tf_opt_layout", False, "Whether to turn on TF layout " + "optimization.") + + +def _default_output_dir(): + """Default output directory.""" + try: + dataset_name = gin.query_parameter("inputs.dataset_name") + except ValueError: + dataset_name = "random" + dir_name = "{model_name}_{dataset_name}_{timestamp}".format( + model_name=gin.query_parameter("train.model").configurable.name, + dataset_name=dataset_name, + timestamp=datetime.datetime.now().strftime("%Y%m%d_%H%M"), + ) + dir_path = os.path.join("~", "trax", dir_name) + print() + trainer_lib.log("No --output_dir specified") + return dir_path + + +def _setup_gin(): + """Setup gin configuration.""" + # Imports for configurables + # pylint: disable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable + from trax import models as _trax_models + from trax import optimizers as _trax_opt + # pylint: disable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable + + configs = FLAGS.config or [] + # Override with --dataset and --model + if FLAGS.dataset: + configs.append("inputs.dataset_name='%s'" % FLAGS.dataset) + if FLAGS.data_dir: + configs.append("inputs.data_dir='%s'" % FLAGS.data_dir) + if FLAGS.model: + configs.append("train.model=@trax.models.%s" % FLAGS.model) + gin.parse_config_files_and_bindings(FLAGS.config_file, configs) + + + + +def main(_): + + logging.set_verbosity(FLAGS.log_level) + + if FLAGS.enable_eager_execution: + tf.enable_eager_execution() + + if FLAGS.tf_xla: + tf.config.optimizer.set_jit(True) + + tf.config.optimizer.set_experimental_options( + {"pin_to_host_optimization": FLAGS.tf_opt_pin_to_host} + ) + + tf.config.optimizer.set_experimental_options( + {"layout_optimizer": FLAGS.tf_opt_layout} + ) + + + _setup_gin() + + if FLAGS.enable_eager_execution and backend.get_name() in ("numpy", "jax"): + # Numpy backend doesn't benefit from having the input pipeline run on GPU, + # and jax backend has GPU memory contention if TF uses the GPU. Gin must be + # set up first before determining the backend. + tf.config.experimental.set_visible_devices([], "GPU") + + # Setup output directory + output_dir = FLAGS.output_dir or _default_output_dir() + trainer_lib.log("Using --output_dir %s" % output_dir) + output_dir = os.path.expanduser(output_dir) + + # If on TPU, let JAX know. + if FLAGS.use_tpu: + jax.config.update("jax_platform_name", "tpu") + + trainer_lib.train(output_dir=output_dir) + + +if __name__ == "__main__": + app.run(main) diff --git a/trax/trainer_lib.py b/trax/trainer_lib.py new file mode 100644 index 000000000..27093b6ff --- /dev/null +++ b/trax/trainer_lib.py @@ -0,0 +1,904 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Trax main training functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import functools +import itertools +import os +import random +import sys +import time + +from absl import logging + +import gin + +import jax +from jax import lax +import numpy +import six +import tensorflow as tf +from tensorflow.io import gfile +from trax import backend +from trax import history as trax_history +from trax import inputs as trax_inputs +from trax import jaxboard +from trax import layers +from trax import learning_rate as lr +from trax import optimizers as trax_opt +from trax import utils +from trax.backend import numpy as np +from trax.backend import random as jax_random + + +def _stack_inputs_targets_and_get_predictions(inputs_and_targets): + """Helper to stack inputs and targets and retrieve predictions from output.""" + # Inputs and targets can be lists - we build a flat one to input to the model. + model_inp = [] + for x in inputs_and_targets: + if not isinstance(x, (list, tuple)): + model_inp.append(x) + else: + model_inp.extend(x) + # We retrieve as many predictions from model output as many there were inputs. + inp = inputs_and_targets[0] + inp_len = len(inp) if isinstance(inp, (list, tuple)) else 1 + get_pred = lambda x: x[0] if inp_len == 1 else x[:inp_len] + return tuple(model_inp), get_pred + + +def log(s, stdout=True): + logging.info(s) + if stdout: + print(s) + sys.stdout.flush() + + +def step_log(step, s): + log("Step % 6d: %s" % (step, s)) + + +State = collections.namedtuple("_State", [ + "step", # Current training step number. + "opt_state", # OptState. + "history", # trax.history.History. + "model_state", +]) + + +OptState = collections.namedtuple("_OptState", [ + "params", # Model parameters. + "slots", # Per-parameter optimizer state, e.g. gradient moments. + "opt_params", # Optimizer (hyper)parameters, e.g. learning rate, momentum. +]) + + +def restore_state(output_dir): + """Restore State.""" + params_file = os.path.join(output_dir, "model.pkl") + if not gfile.exists(params_file): + return State(step=None, opt_state=None, history=trax_history.History(), + model_state=None) + + pkl_module = utils.get_pickle_module() + with gfile.GFile(params_file, "rb") as f: + (opt_state, step, history, model_state) = pkl_module.load(f) + log("Model loaded from %s at step %d" % (params_file, step)) + logging.debug("From loaded model : history = %s", history) + return State(step=step, opt_state=OptState(*opt_state), history=history, + model_state=model_state) + + +def _save_gin(output_dir, sw=None): + config_path = os.path.join(output_dir, "config.gin") + config_str = gin.operative_config_str() + with gfile.GFile(config_path, "w") as f: + f.write(config_str) + if sw: + sw.text("gin_config", + jaxboard.markdownify_operative_config_str(config_str)) + + +def save_state(state, output_dir, keep=False): + """Save State and optionally gin config.""" + pkl_module = utils.get_pickle_module() + params_file = os.path.join(output_dir, "model.pkl") + with gfile.GFile(params_file, "wb") as f: + pkl_module.dump((tuple(state.opt_state), state.step, state.history, + state.model_state), f) + if keep: + params_file = os.path.join(output_dir, "model_{}.pkl".format(state.step)) + with gfile.GFile(params_file, "wb") as f: + pkl_module.dump((tuple(state.opt_state), state.step, state.history, + state.model_state), f) + log("Model saved to %s" % params_file, stdout=False) + + +def _save_replicated(opt_state, step, history, model_state, n_devices, + output_dir, keep): + """Save state but given a possibly replicated opt_state.""" + if n_devices > 1: + first_replica = lambda x: x[0] + opt_state = OptState(*layers.nested_map(opt_state, first_replica)) + # This line, while optional, allows JAX to transfer arrays from the device to + # the host in parallel, which is particularly important for cloud TPU. + if backend.get_name() == "jax": + opt_state = jax.device_get(opt_state) + save_state(State(opt_state=opt_state, step=step, history=history, + model_state=model_state), output_dir, keep=keep) + + +def _print_n_params(opt_state, n_devices, step): + """Print out the number of parameters.""" + sizes = layers.sizes(opt_state.params) + if n_devices > 1: + unreplicate = lambda x: x[0] + single_params = layers.nested_map(opt_state.params, unreplicate) + sizes = layers.sizes(single_params) + total_size = layers.nested_reduce(sizes, sum) + step_log(step, "Total trainable parameters size: %d" % total_size) + + +# Metrics to calculate and report. +_METRICS = { + "accuracy": layers.AccuracyScalar, + "neg_log_perplexity": layers.NegLogPerplexityScalar, + "loss": layers.CrossEntropyLossScalar, +} + + +def evaluation_round(inputs_stream, metric_names, eval_fn, params, state, rng): + """Evaluate. + + Args: + inputs_stream: iterable of inputs to evaluate on. + metric_names: list of strings, the order in which eval_fn returns metrics. + eval_fn: metric function, which takes inputs and predictions (and + params, state, rng) and returns a tuple of scalar metric values. + params: params for each f in eval_fns. + state: state for each f in eval_fns. + rng: random number generator. + + Returns: + metrics: dict from metric name to metric value averaged over the number of + inputs. + state: end state for `predict_fn`. + """ + metrics = collections.defaultdict(float) + count = 0 + for inp in inputs_stream: + count += 1 + rng, subrng = jax_random.split(rng) + metric_values = eval_fn(inp, params=params, state=state, rng=subrng) + try: + metric_values = list(metric_values) + except TypeError: + metric_values = [float(metric_values)] + for m, v in zip(metric_names, metric_values): + metrics[m] += v + return {m: v / count for (m, v) in six.iteritems(metrics)}, state + + +def log_metrics(metrics, summ_writer, log_prefix, step, history=None): + """Log metrics to summary writer and history.""" + rjust_len = max([0] + [len(name) for name in metrics]) + for name, value in six.iteritems(metrics): + step_log(step, "%s %s | % .8f" % ( + log_prefix.ljust(5), name.rjust(rjust_len), value)) + full_name = "metrics/" + name + if history: + history.append(log_prefix, full_name, step, value) + if summ_writer: + summ_writer.scalar(full_name, value, step) + + +def get_random_number_generator_and_set_seed(seed=None): + """Get a JAX random number generator and set random seed everywhere.""" + random.seed(seed) + # While python random accepts None as seed and uses time/os seed then, + # some other functions expect integers so we create one here. + if seed is None: + seed = random.randint(0, 2**31 - 1) + tf.set_random_seed(seed) + numpy.random.seed(seed) + return jax_random.get_prng(seed) + + +def epochs(total_steps, steps_to_skip, epoch_steps): + """Generates the number of steps in each epoch before reaching total_steps. + + Args: + total_steps: int, total number of steps. + steps_to_skip: int, number of steps to skip because of a restart. + epoch_steps: iterable of int, numbers of steps in each epoch. + + Yields: + epoch_steps: int, number of steps in this epoch + """ + steps_to_go = total_steps - steps_to_skip + epoch_steps = iter(epoch_steps) + + # Remove the desired number of steps from the stream. + for steps_this_epoch in epoch_steps: + if steps_this_epoch > steps_to_skip: + # Put back the number of steps left in the unfinished epoch. + epoch_steps = itertools.chain( + [steps_this_epoch - steps_to_skip], epoch_steps) + if steps_this_epoch >= steps_to_skip: + break + steps_to_skip -= steps_this_epoch + + # Yield the remaining steps per epoch up to total_steps. + for steps_this_epoch in epoch_steps: + steps_this_epoch = min(steps_this_epoch, steps_to_go) + yield steps_this_epoch + steps_to_go -= steps_this_epoch + if steps_to_go == 0: + break + + +@gin.configurable +def _jit_predict_fn(model_predict, metric_fn, n_devices, jit=True): + """Returns a JIT-compiled predict function (unless jit=False).""" + model_predict = layers.Serial([model_predict, metric_fn]) + + if n_devices == 1: + return backend.jit(model_predict) if jit else model_predict + + # Multi-devices, pmap and run. + @functools.partial(backend.pmap, axis_name="batch") + def mapped_predict(x, params, state, rng): + return model_predict(x, params=params, state=state, rng=rng) + + def predict(x, params=(), state=(), rng=None): + """Predict function jited and parallelized as requested.""" + pred = mapped_predict( + reshape_by_device(x, n_devices), + params, + state, + jax_random.split(rng, n_devices)) + # Need to reduce the [device, per-device-batch, ...] tensors back to + # a [batch, ...] tensor. The tensors may be nested. + def combine(x): + if len(x.shape) > 1: + batch_size = x.shape[0] * x.shape[1] + return np.reshape(x, [batch_size] + list(x.shape[2:])) + # TODO(lukaszkaiser): is returning averages for scalars the right choice? + # If it is only scalar, return the average. + return np.mean(x, axis=0) + return layers.nested_map(pred, combine) + + return predict + + +@gin.configurable +def _jit_update_fn(predict_fn, loss_fn, optimizer, n_devices, jit=True): + """Returns a (JIT-compiled) function that computes updates for one step.""" + model_and_loss = layers.Serial([predict_fn, loss_fn]) + # Gradients are always wrt. the first argument, so putting params first. + def model_and_loss_call(params, batch, state, rng): + res = model_and_loss(batch, params=params, state=state, rng=rng) + return res, model_and_loss.state + if n_devices == 1: # TODO(lukaszkaiser): remove branch when not needed. + def single_update(i, opt_state, batch, state, rng): + params, slots, opt_params = opt_state + rng, subrng = jax_random.split(rng[0]) + grad_fn = backend.grad(model_and_loss_call, has_aux=True) + grads, state = grad_fn(params, batch, state, rng) + return optimizer.tree_update( + i, grads, params, slots, opt_params), state, [subrng] + return backend.jit(single_update) if jit else single_update + + # Else, for n_devices > 1: + @functools.partial(backend.pmap, axis_name="batch") + def mapped_update(i, opt_state, batch, state, rng): + """This is a multi-device version of the update function above.""" + # We assume all tensors have the first dimension = n_devices. + params, slots, opt_params = opt_state + rng, subrng = jax_random.split(rng) + grad_fn = backend.grad(model_and_loss_call, has_aux=True) + grads, state = grad_fn(params, batch, state, rng) + grads = jax.tree_util.tree_map( + lambda g: lax.psum(g, "batch"), grads) + return optimizer.tree_update( + i, grads, params, slots, opt_params), state, subrng + + def update(i, opt_state, batch, state, rng): + return mapped_update(numpy.repeat(i, n_devices), opt_state, batch, state, + rng) + + return update + + +@gin.configurable +def _jit_compute_loss_fn(predict_fn, loss_fn, n_devices, jit=True): + """Returns a (JIT-compiled) function that computes the loss for one step.""" + if n_devices == 1: # TODO(lukaszkaiser): remove branch when not needed. + def single_compute_loss(opt_state, batch, state, rng): + rng, subrng = jax_random.split(rng[0]) + loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state, rng) + return loss_val, state, [subrng] + return backend.jit(single_compute_loss) if jit else single_compute_loss + + # Else, for n_devices > 1: + @functools.partial(backend.pmap, axis_name="batch") + def mapped_compute_loss(opt_state, batch, state, rng): + """This is a multi-device version of the update function above.""" + # We assume all tensors have the first dimension = n_devices. + rng, subrng = jax_random.split(rng) + loss_val, state = loss_fn(opt_state[0], batch, predict_fn, state, rng) + return loss_val, state, subrng + + def compute_loss(opt_state, batch, state, rng): + return mapped_compute_loss( + opt_state, reshape_by_device(batch, n_devices), state, rng) + + return compute_loss + + +@gin.configurable +def _is_jit_init(value=True): + return value + + +def _reshape_by_device_single(x, n_devices): + """Reshape x into a shape [n_devices, ...].""" + x_shape = list(x.shape) + batch_size = x_shape[0] + batch_size_per_device = batch_size // n_devices + # We require that n_devices divides batch_size evenly. + if batch_size_per_device * n_devices != batch_size: + logging.fatal( + "We require that n_devices[%d] divides batch_size[%d] evenly.", + n_devices, batch_size) + # New shape. + new_shape_prefix = [n_devices, batch_size_per_device] + return np.reshape(x, new_shape_prefix + x_shape[1:]) + + +def reshape_by_device(x, n_devices): + """Reshape possibly nested x into a shape [n_devices, ...].""" + return layers.nested_map( + x, lambda x: _reshape_by_device_single(x, n_devices)) + + +def multi_device_put(x, devices=None, reuse=True): + """Memory efficient multi-device replication in JAX. + + Args: + x: jax DeviceArray or numpy ndarray to be replicated. + devices: a jax.devices() list or subset thereof of devices to + replicate onto. Should match the list passed to any pmaps + ingesting the replicated array. + reuse: bool. If x is a DeviceArray whether to reuse its backing + device_buffer in the resulting ShardedDeviceArray. + + Returns: + A ShardedDeviceArray with dtype = x.dtype and shape = + (n_devices,) + x.shape that's backed by replica + device_buffers on each device. + """ + # Convert _FilledConstants that don't have device_buffer, etc. + if type(x) != jax.xla.DeviceArray: # pylint: disable=unidiomatic-typecheck + x = np.array(x) + if not devices: + devices = jax.devices() + n_devices = len(devices) + x_aval = jax.xla.abstractify(x) + broadcast_x_aval = jax.abstract_arrays.ShapedArray( + (n_devices,) + x_aval.shape, + x_aval.dtype) + if reuse: + other_device_ordinals = [dv.id for dv in jax.devices() + if dv != x.device_buffer.device()] + broadcast_buffers = ([x.device_buffer,] + + [jax.xla.xc.Buffer.from_pyval(x, device=i) + for i in other_device_ordinals]) + else: + broadcast_buffers = [jax.xla.xc.Buffer.from_pyval(x, device=i) + for i in range(n_devices)] + return jax.pxla.ShardedDeviceArray(broadcast_x_aval, broadcast_buffers) + + +def _repeat_stream(stream): + """Repeat a stream indefinitely.""" + while True: + for example in stream(): + yield example + + +@gin.configurable(whitelist=[]) +class Trainer(object): + """Trax trainer. + + A trainer allows to make training steps, train for full epochs, + save the training state and access evaluation data. + """ + + def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs, + output_dir=None, random_seed=None, n_devices=None, + save_steps=None, should_save_checkpoints=True, + should_write_summaries=True, has_weights=False, + nontrainable_param_map=None, mask_id=None): + if save_steps is None: + save_steps = [] + self._save_steps = save_steps + self._should_save_checkpoints = should_save_checkpoints + self._should_write_summaries = should_write_summaries + self._has_weights = has_weights + self._mask_id = mask_id + loss_fn = loss_fn(has_weights=has_weights, mask_id=mask_id) + device_count = jax.lib.xla_bridge.device_count() + n_devices = n_devices or device_count + # TODO(lukaszkaiser): remove this restriction when possible. + if n_devices != device_count: + raise ValueError("Jax cannot work yet with n_devices != all devices: " + "%d != %d" % (n_devices, device_count)) + self._n_devices = n_devices + rng = get_random_number_generator_and_set_seed(random_seed) + inputs = inputs(n_devices) + self._inputs = inputs + + # Initialize the learning rate to a dummy value. It will be set in reset(). + opt = optimizer(learning_rate=0.0) + + # Setup the model. + model_train = model(mode="train") + model_predict_eval = model(mode="eval") + + # Setup state. + rng, init_rng = jax_random.split(rng) + self._rngs = jax_random.split(rng, n_devices) + first_shape = inputs.input_shape[0] + # If the inputs are a tuple/list, add [None] (batch) to each element. + if isinstance(first_shape, (list, tuple)): + model_input_shape = tuple( + tuple([None] + list(shape)) for shape in inputs.input_shape) + model_target_shape = tuple( + tuple([None] + list(shape)) for shape in inputs.target_shape) + else: # Otherwise just add [None] to the input shape. + model_input_shape = tuple([None] + list(inputs.input_shape)) + model_target_shape = tuple([None] + list(inputs.target_shape)) + # Change all None to 1 in input and target shape. + model_input_shape = layers.nested_map( + model_input_shape, lambda x: x if x else 1) + model_target_shape = layers.nested_map( + model_target_shape, lambda x: x if x else 1) + def new_opt_state_and_model_state(input_shape, input_dtype, target_shape, + target_dtype, rng): + """Returns optimizer and model states suitable for training a model.""" + # Combine inputs and targets on the stack. + if not isinstance(input_dtype, (list, tuple)): + input_dtype = [input_dtype] + input_shape = [input_shape] + if not isinstance(target_dtype, (list, tuple)): + target_dtype = [target_dtype] + target_shape = [target_shape] + full_type = list(input_dtype) + list(target_dtype) + full_shape = list(input_shape) + list(target_shape) + if self._has_weights: + full_shape += list(target_shape) + full_type += [np.float32 for _ in target_dtype] + # We need to create a new model instance and not reuse `model_train` here, + # because `m.initialize` puts cached parameter values in `m` and hence the + # next call of `m.initialize` will give wrong results. + m = layers.Serial([model(mode="train"), loss_fn]) + params, state = m.initialize_once(full_shape, full_type, rng) + (slots, opt_params) = opt.tree_init(params) + return (OptState(params, slots, opt_params), state) + if _is_jit_init(): + # JIT parameter initialization to avoid memory fragmentation + new_opt_state_and_model_state = backend.jit(new_opt_state_and_model_state, + static_argnums=(0, 1, 2, 3)) + self._new_opt_state_and_model_state = ( + lambda: new_opt_state_and_model_state( # pylint: disable=g-long-lambda + model_input_shape, self._inputs.input_dtype, + model_target_shape, self._inputs.target_dtype, init_rng)) + + # jit model_predict and update so they're fast + # TODO(lukaszkaiser): the code below creates a layer computing + # multiple metrics from a single model output; re-factor for clarity. + dup_layer = layers.Dup3() if self._has_weights else layers.Dup2() + def lower(layer): + """Apply layer below the current inputs, targets, and possibly weights.""" + if self._has_weights: + # Apply layer below inputs, targets, and loss weights. + return layers.Parallel([], [], [], layer) + else: + # Apply layer below inputs and targets. + return layers.Parallel([], [], layer) + metrics_layer = [] + self._metrics = list(sorted(_METRICS.keys())) + for i, m in enumerate(reversed(self._metrics)): + metric = _METRICS[m](has_weights=self._has_weights, mask_id=self._mask_id) + if i != len(self._metrics) - 1: + metrics_layer.append(dup_layer) + metrics_layer.append(lower(metric)) + else: + metrics_layer.append(metric) + # TODO(lukaszkaiser): clean this up once layer API stabilizes. + # For now, we need to initialize metric layers somehow, so here we go. + # We assume that they do not have any parameters, so this is a dummy. + dummy_shape = ((1, 2), (1,), (1,)) if self._has_weights else ((1, 2), (1,)) + dummy_type = [np.float32] * (3 if self._has_weights else 2) + metrics_layer = layers.Serial(metrics_layer) + metrics_params, metrics_state = metrics_layer.initialize_once( + dummy_shape, tuple(dummy_type), init_rng) + self._metrics_params = layers.nested_map( + metrics_params, self._maybe_replicate) + self._metrics_state = layers.nested_map( + metrics_state, self._maybe_replicate) + self._jit_eval = _jit_predict_fn( + model_predict_eval, metrics_layer, n_devices) + self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices) + + self._model_train = model_train + self._model_predict_eval = model_predict_eval + self._loss_fn = loss_fn + # TODO(pkozakowski): "Learning rate schedules" are currently able to control + # control all optimizer parameters and model state, so let's rename them + # accordingly. + self._lr_schedule = lr_schedule + + if nontrainable_param_map is None: + nontrainable_param_map = {} + self._nontrainable_param_map = nontrainable_param_map + + # Those fields will be set in reset(). + self._output_dir = None + self._train_sw = None + self._eval_sw = None + self._history = None + self._lr_fn = None + self._opt_state = None + self._step = None + self._model_state = None + + if output_dir is not None: + self.reset(output_dir) + + def reset(self, output_dir): + """Reset the model parameters. + + Restores the parameters from the given output_dir if a checkpoint exists, + otherwise randomly initializes them. + + Does not re-jit the model. + + Args: + output_dir: Output directory. + """ + self._output_dir = output_dir + gfile.makedirs(output_dir) + # Create summary writers and history. + if self._should_write_summaries: + self._train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train")) + self._eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval")) + + # Reset the train and eval streams. + self._train_stream = self._inputs.train_stream() + # TODO(lukaszkaiser): add an option to evaluate exactly on the full eval + # set by adding a padding and stopping the stream when too large. + self._eval_stream = _repeat_stream(self._inputs.eval_stream) + self._train_eval_stream = _repeat_stream(self._inputs.train_eval_stream) + + # Restore the training state. + state = restore_state(output_dir) + self._step = state.step or 0 + history = state.history + self._lr_fn = self._lr_schedule(history) + self._history = history + if state.opt_state: + opt_state = state.opt_state + model_state = state.model_state + else: + opt_state, model_state = self._new_opt_state_and_model_state() + model_state = layers.nested_map( + model_state, self._maybe_replicate) + self._opt_state = OptState(*layers.nested_map( + opt_state, self._maybe_replicate)) + self._model_state = model_state + if not state.opt_state: + self._maybe_save_state(keep=False) + + self.update_nontrainable_params() + + @property + def step(self): + return self._step + + @property + def n_devices(self): + return self._n_devices + + @property + def state(self): + return State( + opt_state=self._opt_state, step=self._step, history=self._history, + model_state=self._model_state) + + @property + def nontrainable_params(self): + # TODO(lukaszkaiser): it makes no sense to use an accelerator (e.g. TPU) + # in op-by-op mode just to compute the learning rate. However, there + # should be a cleaner approach that forceably swapping out the backend. + with backend.use_backend("numpy"): + return self._lr_fn(self._step) + + def _maybe_replicate(self, x): + if self._n_devices > 1: + if backend.get_name() == "jax": + return multi_device_put(x) + else: + return np.broadcast_to(x, (self._n_devices,) + x.shape) + else: + return x + + def _maybe_save_state(self, keep): + if self._should_save_checkpoints: + _save_replicated(self._opt_state, self._step, self._history, + self._model_state, self._n_devices, self._output_dir, + keep) + + def save_gin(self): + _save_gin(self._output_dir, self._train_sw) + + def print_n_params(self): + _print_n_params(self._opt_state, self._n_devices, self._step) + + def _map_to_state_dicts(self, f): + """Map the function f to all dicts in model state.""" + def nested_map(x, f): + if isinstance(x, list): + return [nested_map(y, f) for y in x] + if isinstance(x, tuple): + return tuple([nested_map(y, f) for y in x]) + if isinstance(x, dict) and len(x) == 1: + return f(x) + return x + return nested_map(self._model_state, f) + + def _state_dicts_update(self, state_dict): + assert len(state_dict.keys()) == 1 + key = list(state_dict.keys())[0] + value = np.array(state_dict[key]) + return {key: np.array(self.update_model_state(key, value))} + + def update_model_state(self, key, value): + """Updates model state based on nontrainable_params.""" + # Translate model state keys to nontrainable param names. + if key in self._nontrainable_param_map: + param_name = self._nontrainable_param_map[key] + else: + # If a key is not in mapping, it stays the same. + param_name = key + if param_name in self.nontrainable_params: + if self._step == 0: + log("Mapping model state key {} to nontrainable param {}.".format( + key, param_name + )) + return self._maybe_replicate( + np.array(self.nontrainable_params[param_name]) + ) + return value + + def _train_step(self, next_train_batch): + """Run one training step and update self._opt_state.""" + # Calculate the current optimizer parameters. + # TODO(pkozakowski): Optimizer parameters get polluted with model state, + # which doesn't break anything but is weird. Filter it out. + opt_param_updates = layers.nested_map( + self.nontrainable_params, lambda x: self._maybe_replicate(np.array(x)) + ) + opt_state = self._opt_state + opt_state.opt_params.update(opt_param_updates) + + # Run the update. + (params, slots), self._model_state, self._rngs = self._jit_update_fn( + self._step, opt_state, next_train_batch, self._model_state, self._rngs) + self._model_state = self._map_to_state_dicts(self._state_dicts_update) + self._opt_state = opt_state._replace(params=params, slots=slots) + self._step += 1 + + def train_epoch(self, epoch_steps, eval_steps): + """Train for one epoch.""" + # Log separator + print() + + # Timer + start_time = time.time() + + for _ in range(epoch_steps): + # Train + next_train_batch = next(self._train_stream) + if self._n_devices > 1: # TODO(lukaszkaiser): use everywhere if possible. + next_train_batch = reshape_by_device(next_train_batch, self._n_devices) + + self._train_step(next_train_batch) + + if self._step in self._save_steps: + self._maybe_save_state(keep=True) + + # Log nontrainable params (learning rate, dropout etc.) + if (self._step == 1 or self._step % 10 == 0) and self._train_sw: + for (name, value) in self.nontrainable_params.items(): + self._train_sw.scalar("training/{}".format(name), value) + + # Timer + epoch_time = time.time() - start_time + step_log(self._step, "Ran %d train steps in %0.2f secs" % + (epoch_steps, epoch_time)) + if epoch_steps > 1 and self._train_sw: + self._train_sw.scalar("training/steps per second", + epoch_steps / epoch_time, step=self._step) + + # Evaluate in parallel + self.evaluate(eval_steps) + + # Save state + self._maybe_save_state(keep=False) + + # Flush summary writers + if self._train_sw: + self._train_sw.flush() + self._eval_sw.flush() + + def evaluate(self, eval_steps): + """Evaluate the model and log metrics.""" + _, rng = jax_random.split(self._rngs[0]) + # TODO(lukaszkaiser): both model state and parameters by default include + # the loss layer. Currently, we access the pure-model parameters by just + # indexing, [0] here. But we should make it more explicit in a better API. + params = (self._opt_state[0][0], self._metrics_params) + state = (self._model_state[0], self._metrics_state) + step_log(self._step, "Evaluation") + train_eval_slice = itertools.islice(self._train_eval_stream, eval_steps) + train_metrics, _ = evaluation_round( + train_eval_slice, self._metrics, self._jit_eval, params, state, rng) + log_metrics(train_metrics, self._train_sw, "train", + self._step, history=self._history) + eval_slice = itertools.islice(self._eval_stream, eval_steps) + eval_metrics, _ = evaluation_round( + eval_slice, self._metrics, self._jit_eval, params, state, rng) + log_metrics(eval_metrics, self._eval_sw, "eval", + self._step, history=self._history) + step_log(self._step, "Finished evaluation") + + # Save the optimizer params in the history + for (name, value) in self.nontrainable_params.items(): + self._history.append("train", "training/{}".format(name), self._step, + value) + + def update_nontrainable_params(self): + self._lr_fn = self._lr_schedule(self._history) + + def save_computation_graphs(self, save_backward_graph): + """Dump computation graphs to files.""" + if self._n_devices != 1: + return # TODO(lukaszkaiser): make this work with more devices. + next_train_batch = next(self._train_stream) + output_dir = self._output_dir + if self._n_devices > 1: + next_train_batch = reshape_by_device(next_train_batch, self._n_devices) + params = self._opt_state[0][0] + forward_computation = jax.xla_computation(self._model_predict_eval)( + next_train_batch, params=params, state=self._model_state[0], + rng=self._rngs[0]) + with gfile.GFile(os.path.join(output_dir, "forward.txt"), "w") as f: + f.write(forward_computation.GetHloText()) + with gfile.GFile(os.path.join(output_dir, "forward.dot"), "w") as f: + f.write(forward_computation.GetHloDotGraph()) + backward_computation = jax.xla_computation(self._jit_update_fn)( + self._step, self._opt_state, next_train_batch, self._model_state, + self._rngs) + with gfile.GFile(os.path.join(output_dir, "backward.txt"), "w") as f: + f.write(backward_computation.GetHloText()) + if save_backward_graph: # Backward graphs can be large so we guard it. + with gfile.GFile(os.path.join(output_dir, "backward.dot"), "w") as f: + f.write(backward_computation.GetHloDotGraph()) + + +@gin.configurable(blacklist=["output_dir"]) +def train(output_dir, + model=gin.REQUIRED, + loss_fn=layers.CrossEntropyLossScalar, + inputs=trax_inputs.inputs, + optimizer=trax_opt.Adafactor, + lr_schedule=lr.MultifactorSchedule, + trainer_class=Trainer, + train_steps=1000, + save_steps=None, + eval_steps=10, + eval_frequency=100, + n_devices=None, + random_seed=None, + save_graphs=True, + save_backward_graph=False, + has_weights=False, + nontrainable_param_map=None, + mask_id=None): + """Train the model on the inputs. + + Args: + output_dir: Directory where to put the logs and checkpoints. + model: The model to train as a callable returning 2 callables, an init_fn + and apply_fn. + loss_fn: callable with signature: params, trax.inputs.Inputs, model, state, + rng -> loss. + inputs: callable returning trax.inputs.Inputs. + optimizer: The optimizer (see optimizers/base.py for signature). + lr_schedule: A learning rate schedule as a function that takes history and + returns a function from step to learning rate (a float). + trainer_class: The trainer class to use. + train_steps: int, total number of training steps. + save_steps: list of integers. Keep a model file at each of the supplied save + steps. + eval_steps: int, num of steps per evaluation. If None or 0, eval disabled. + eval_frequency: int, how often to run evaluation (every eval_frequency + steps). If None or 0, eval disabled. + n_devices: how many devices to use (if None, default, use all available) + random_seed: the random seed to use; time/os dependent if None (default). + save_graphs: bool, if True, save computation graph to file. + save_backward_graph: bool, if True, save backward graph to file too. + has_weights: bool, whether weights are included in the inputs. + nontrainable_param_map: dict, mapping from model nontrainable parameter + names to control names in PolicySchedule. + mask_id: id to mask out (None by default). + + Returns: + trax.State + """ + # TODO(lukaszkaiser): remove has_weights and mask_id later (configure loss). + trainer = trainer_class(model, loss_fn, optimizer, lr_schedule, inputs, + output_dir, + random_seed=random_seed, n_devices=n_devices, + save_steps=save_steps, has_weights=has_weights, + nontrainable_param_map=nontrainable_param_map, + mask_id=mask_id) + + epoch_steps = [train_steps] # Only training if eval_frequency is 0 or None + if eval_frequency and eval_steps > 0: + epoch_steps = itertools.chain([1, # first epoch only 1 step + eval_frequency - 1], + itertools.repeat(eval_frequency)) + step_log(trainer.step, + "Starting training using %d devices" % trainer.n_devices) + + for epoch_steps in epochs(train_steps, trainer.step, epoch_steps): + trainer.train_epoch(epoch_steps, eval_steps) + + # Update nontrainable parameters with new history + trainer.update_nontrainable_params() + + # Bookkeeping we do at the first step + if trainer.step == 1: + # Print number of parameters + trainer.print_n_params() + + # Save computation graph (single-device only for now) + if (save_graphs and backend.get_name() == "jax"): + trainer.save_computation_graphs(save_backward_graph) + + # Save Gin config + trainer.save_gin() + + step_log(trainer.step, "Training done") + return trainer.state diff --git a/trax/trainer_lib_test.py b/trax/trainer_lib_test.py new file mode 100644 index 000000000..1c05a88bd --- /dev/null +++ b/trax/trainer_lib_test.py @@ -0,0 +1,267 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""trax test.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import functools +import tempfile +from absl.testing import parameterized + +from jax import test_util # pylint: disable=unused-import +from jax.config import config +from jax.lib import xla_bridge + +import tensorflow as tf +from tensorflow import test +from tensorflow.io import gfile + +from trax import backend +from trax import inputs as inputs_lib +from trax import layers +from trax import learning_rate as lr +from trax import models +from trax import optimizers as trax_opt +from trax import trainer_lib +from trax.backend import numpy as np + + + +def test_inputs(n_classes, with_weights=False, input_shape=(6, 6, 3)): + """Make trainer_lib.inputs.Inputs.""" + batch_size = 2 * xla_bridge.device_count() + + def input_stream(): + key = backend.random.get_prng(0) + while True: + keys = backend.random.split(key, 4) + key = keys[0] + inputs = backend.random.uniform(keys[1], [batch_size] + list(input_shape)) + targets = backend.random.randint( + keys[2], [batch_size], dtype=np.int32, minval=0, maxval=n_classes) + weights = backend.random.uniform(keys[3], [batch_size]) + if with_weights: + yield inputs, targets, weights + else: + yield inputs, targets + + return inputs_lib.Inputs( + train_stream=input_stream, + train_eval_stream=input_stream, + eval_stream=input_stream, + input_shape=input_shape, + input_dtype=np.float32, + target_shape=(), + target_dtype=np.int32) + + + +BACKENDS = ["jax"] + + +class TraxTest(test.TestCase, parameterized.TestCase): + + @contextlib.contextmanager + def tmp_dir(self): + tmp = tempfile.mkdtemp(dir=self.get_temp_dir()) + yield tmp + gfile.rmtree(tmp) + + # TODO(wangpeng): Remove `skipTest`'s when tf-numpy's `pmap` is in place + + @parameterized.parameters(BACKENDS) + def test_train_eval_predict(self, backend_name): + if xla_bridge.device_count() > 1 and backend_name == "tf": + self.skipTest("tf-numpy backend doesn't support multi-devices yet.") + with backend.use_backend(backend_name), self.tmp_dir() as output_dir: + # Prepare model and inputs + n_classes = 4 + train_steps = 2 + eval_steps = 2 + + # Adds Dropout and BatchNorm to test state handling. + def model_fn(mode="train"): + return layers.Model( + layers.Dropout(mode=mode, rate=0.1), layers.BatchNorm(mode=mode), + models.MLP(d_hidden=16, n_output_classes=n_classes, mode=mode)) + + inputs = lambda _: test_inputs(n_classes) + + # Train and evaluate + state = trainer_lib.train( + output_dir, + model=model_fn, + inputs=inputs, + train_steps=train_steps, + eval_steps=eval_steps) + + # Assert total train steps + self.assertEqual(train_steps, state.step) + + # Assert 2 evaluations ran + train_acc = state.history.get("train", "metrics/accuracy") + eval_acc = state.history.get("eval", "metrics/accuracy") + self.assertEqual(len(train_acc), len(eval_acc)) + self.assertLen(eval_acc, 2) + + # Predict with final params + inputs = inputs(1).train_stream() + model = layers.Serial(model_fn()) + model(next(inputs)[0], params=state.opt_state.params) + + @parameterized.parameters(BACKENDS) + def test_train_eval_predict_sm3(self, backend_name): + if xla_bridge.device_count() > 1 and backend_name == "tf": + self.skipTest("tf-numpy backend doesn't support multi-devices yet.") + with backend.use_backend(backend_name), self.tmp_dir() as output_dir: + # Prepare model and inputs + n_classes = 4 + train_steps = 2 + eval_steps = 2 + model_fn = functools.partial( + models.MLP, d_hidden=16, n_output_classes=n_classes) + inputs = lambda _: test_inputs(n_classes) + + # Train and evaluate + state = trainer_lib.train( + output_dir, + model=model_fn, + inputs=inputs, + train_steps=train_steps, + eval_steps=eval_steps, + optimizer=trax_opt.SM3) + + # Assert total train steps + self.assertEqual(train_steps, state.step) + + # Assert 2 evaluations ran + train_acc = state.history.get("train", "metrics/accuracy") + eval_acc = state.history.get("eval", "metrics/accuracy") + self.assertEqual(len(train_acc), len(eval_acc)) + self.assertLen(eval_acc, 2) + + # Predict with final params + inputs = inputs(1).train_stream() + model = layers.Serial(model_fn()) + model(next(inputs)[0], params=state.opt_state.params) + + @parameterized.parameters(BACKENDS) + def test_train_restart(self, backend_name): + if xla_bridge.device_count() > 1 and backend_name == "tf": + self.skipTest("tf-numpy backend doesn't support multi-devices yet.") + with backend.use_backend(backend_name), self.tmp_dir() as output_dir: + # Prepare model and inputs + n_classes = 4 + train_steps = 2 + eval_steps = 2 + model_fn = functools.partial( + models.MLP, d_hidden=16, n_output_classes=n_classes) + inputs = lambda _: test_inputs(n_classes) + + # Train and evaluate + trainer_lib.train( + output_dir, + model=model_fn, + inputs=inputs, + train_steps=train_steps, + eval_steps=eval_steps) + + # Restart training + state = trainer_lib.train( + output_dir, + model=model_fn, + inputs=inputs, + train_steps=(2 * train_steps), + eval_steps=eval_steps) + + # Assert total train steps + self.assertEqual(state.step, 2 * train_steps) + + @parameterized.parameters(BACKENDS) + def test_train_with_weights(self, backend_name): + if xla_bridge.device_count() > 1 and backend_name == "tf": + self.skipTest("tf-numpy backend doesn't support multi-devices yet.") + with backend.use_backend(backend_name), self.tmp_dir() as output_dir: + # Prepare model and inputs + n_classes = 4 + train_steps = 2 + eval_steps = 2 + model_fn = functools.partial( + models.MLP, d_hidden=16, n_output_classes=n_classes) + inputs = lambda _: test_inputs(n_classes, with_weights=True) + + # Train and evaluate + state = trainer_lib.train( + output_dir, + model=model_fn, + inputs=inputs, + train_steps=train_steps, + eval_steps=eval_steps, + has_weights=True) + + # Assert total train steps + self.assertEqual(state.step, train_steps) + + @parameterized.parameters(BACKENDS) + def test_reset_twice(self, backend_name): + if xla_bridge.device_count() > 1 and backend_name == "tf": + self.skipTest("tf-numpy backend doesn't support multi-devices yet.") + with backend.use_backend(backend_name), self.tmp_dir() as output_dir1, \ + self.tmp_dir() as output_dir2: + n_classes = 4 + model_fn = functools.partial( + models.MLP, d_hidden=16, n_output_classes=n_classes) + inputs = lambda _: test_inputs(n_classes) + + trainer = trainer_lib.Trainer( + model=model_fn, + loss_fn=layers.CrossEntropyLossScalar, + optimizer=trax_opt.SM3, + lr_schedule=lr.MultifactorSchedule, + inputs=inputs, + ) + + trainer.reset(output_dir1) + trainer.evaluate(1) + trainer.reset(output_dir2) + trainer.evaluate(1) + + + +class EpochsTest(test.TestCase): + + def test_cuts_epoch_when_total_steps_reached(self): + epoch_steps = trainer_lib.epochs( + total_steps=5, steps_to_skip=0, epoch_steps=[1, 2, 3]) + self.assertEqual(list(epoch_steps), [1, 2, 2]) + + def test_skips_full_epoch(self): + epoch_steps = trainer_lib.epochs( + total_steps=4, steps_to_skip=2, epoch_steps=[2, 2]) + self.assertEqual(list(epoch_steps), [2]) + + def test_skips_part_of_epoch(self): + epoch_steps = trainer_lib.epochs( + total_steps=4, steps_to_skip=1, epoch_steps=[2, 2]) + self.assertEqual(list(epoch_steps), [1, 2]) + + +if __name__ == "__main__": + config.config_with_absl() + test.main() diff --git a/trax/utils.py b/trax/utils.py new file mode 100644 index 000000000..e59c8a115 --- /dev/null +++ b/trax/utils.py @@ -0,0 +1,43 @@ +# coding=utf-8 +# Copyright 2019 The Trax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pickle +import sys + +import cloudpickle +import numpy as np + + +def get_pickle_module(): + """Returns the appropriate pickle module based on Python version.""" + # TODO(gilmer, lukaszkaiser): figure out how to use cloudpickle in python3. + # Currently the code throws an error when run in python3. + if sys.version_info[0] < 3: + return cloudpickle + else: + return pickle + + +def gumbel_sample(log_probs): + """Gumbel sampling from a categorical distribution.""" + u = np.random.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape) + g = -np.log(-np.log(u)) + return np.argmax(log_probs + g, axis=-1)