From 8a65ba46c4143efc0b4f7670f23a77cc99c018a9 Mon Sep 17 00:00:00 2001 From: Matteo Manica Date: Fri, 11 Feb 2022 23:18:07 +0100 Subject: [PATCH] feat: initial release for GT4SD project. Signed-off-by: Matteo Manica --- .gitignore | 12 + CITATION.cff | 10 + CONTRIBUTING.md | 125 ++ README.md | 166 ++- conda.yml | 8 + dev_requirements.txt | 14 + docs/Makefile | 24 + docs/_static/gt4sd_logo.png | Bin 0 -> 145035 bytes docs/_templates/module.rst | 75 ++ docs/_templates/package.rst | 115 ++ docs/app.py | 85 ++ docs/conf.py | 119 ++ docs/index.md | 21 + docs/make.bat | 35 + docs/source/gt4sd_algorithm_addition_md.md | 358 ++++++ docs/source/gt4sd_inference_usage_md.md | 97 ++ extras_requirements.txt | 3 + pyproject.toml | 13 + requirements.txt | 29 + setup.cfg | 156 +++ setup.py | 6 + src/gt4sd/__init__.py | 4 + src/gt4sd/algorithms/__init__.py | 43 + .../conditional_generation/__init__.py | 0 .../guacamol/__init__.py | 27 + .../conditional_generation/guacamol/core.py | 673 ++++++++++ .../guacamol/implementation/__init__.py | 599 +++++++++ .../guacamol/implementation/graph_ga.py | 56 + .../guacamol/implementation/graph_mcts.py | 65 + .../guacamol/implementation/moses_aae.py | 46 + .../guacamol/implementation/moses_organ.py | 48 + .../guacamol/implementation/moses_vae.py | 51 + .../guacamol/implementation/smiles_ga.py | 56 + .../guacamol/implementation/smiles_lstm_hc.py | 70 ++ .../implementation/smiles_lstm_ppo.py | 54 + .../key_bert/__init__.py | 5 + .../conditional_generation/key_bert/core.py | 180 +++ .../key_bert/implementation.py | 97 ++ .../paccmann_rl/__init__.py | 13 + .../paccmann_rl/core.py | 269 ++++ .../paccmann_rl/implementation.py | 382 ++++++ .../regression_transformer/__init__.py | 12 + .../regression_transformer/core.py | 352 ++++++ .../regression_transformer/implementation.py | 589 +++++++++ .../reinvent/__init__.py | 5 + .../conditional_generation/reinvent/core.py | 123 ++ .../reinvent/implementation.py | 114 ++ .../reinvent/reinvent_core/LICENSE | 201 +++ .../reinvent/reinvent_core/README.md | 7 + .../reinvent/reinvent_core/__init__.py | 0 .../reinvent/reinvent_core/core.py | 125 ++ .../template/__init__.py | 8 + .../conditional_generation/template/core.py | 105 ++ .../template/implementation.py | 35 + .../conditional_generation/tests/__init__.py | 0 .../tests/test_guacamol.py | 230 ++++ .../tests/test_key_bert.py | 105 ++ .../tests/test_moses.py | 127 ++ .../tests/test_paccmann_rl.py | 162 +++ .../tests/test_regression_transformer.py | 190 +++ .../tests/test_reinvent.py | 104 ++ .../controlled_sampling/__init__.py | 0 .../advanced_manufacturing/__init__.py | 5 + .../advanced_manufacturing/core.py | 135 ++ .../implementation/__init__.py | 0 .../implementation/core.py | 495 ++++++++ .../implementation/nccr/__init__.py | 7 + .../implementation/nccr/core.py | 222 ++++ .../class_controlled_sampling/__init__.py | 14 + .../class_controlled_sampling/core.py | 197 +++ .../implementation.py | 246 ++++ .../paccmann_gp/__init__.py | 5 + .../controlled_sampling/paccmann_gp/core.py | 253 ++++ .../paccmann_gp/implementation.py | 244 ++++ .../controlled_sampling/tests/__init__.py | 0 .../tests/test_advanced_manufacturing.py | 155 +++ .../tests/test_class_controlled_sampling.py | 200 +++ .../tests/test_paccmann_gp.py | 132 ++ src/gt4sd/algorithms/core.py | 493 ++++++++ src/gt4sd/algorithms/generation/__init__.py | 0 .../generation/hugging_face/__init__.py | 21 + .../generation/hugging_face/core.py | 314 +++++ .../generation/hugging_face/implementation.py | 270 ++++ 
.../algorithms/generation/molgx/__init__.py | 17 + src/gt4sd/algorithms/generation/molgx/core.py | 227 ++++ .../generation/molgx/implementation.py | 239 ++++ .../generation/polymer_blocks/__init__.py | 5 + .../generation/polymer_blocks/core.py | 122 ++ .../polymer_blocks/implementation.py | 126 ++ .../algorithms/generation/tests/__init__.py | 0 .../generation/tests/test_hugging_face.py | 193 +++ .../algorithms/generation/tests/test_molgx.py | 102 ++ .../generation/tests/test_polymer_blocks.py | 103 ++ src/gt4sd/algorithms/prediction/__init__.py | 0 src/gt4sd/algorithms/prediction/core.py | 11 + src/gt4sd/algorithms/prediction/paccmann.py | 0 .../algorithms/prediction/tests/__init__.py | 0 .../prediction/tests/test_topics_zero_shot.py | 101 ++ .../prediction/topics_zero_shot/__init__.py | 5 + .../prediction/topics_zero_shot/core.py | 111 ++ .../topics_zero_shot/implementation.py | 79 ++ src/gt4sd/algorithms/registry.py | 348 +++++ src/gt4sd/algorithms/tests/__init__.py | 0 src/gt4sd/algorithms/tests/test_config.py | 47 + src/gt4sd/algorithms/tests/test_registry.py | 138 ++ src/gt4sd/cli/__init__.py | 1 + src/gt4sd/cli/argument_parser.py | 150 +++ src/gt4sd/cli/hf_to_st_converter.py | 100 ++ .../cli/load_arguments_from_dataclass.py | 87 ++ src/gt4sd/cli/pl_to_hf_converter.py | 96 ++ src/gt4sd/cli/trainer.py | 148 +++ src/gt4sd/configuration.py | 123 ++ src/gt4sd/conftest.py | 15 + src/gt4sd/domains/__init__.py | 0 src/gt4sd/domains/core.py | 7 + src/gt4sd/domains/materials/__init__.py | 58 + .../domains/materials/protein_encoding.py | 219 ++++ src/gt4sd/domains/materials/scorer.py | 339 +++++ src/gt4sd/exceptions.py | 81 ++ src/gt4sd/extras/__init__.py | 11 + src/gt4sd/frameworks/__init__.py | 0 src/gt4sd/frameworks/enzeptional/__init__.py | 6 + .../frameworks/enzeptional/optimization.py | 422 +++++++ .../frameworks/enzeptional/processing.py | 234 ++++ .../frameworks/enzeptional/tests/__init__.py | 0 .../enzeptional/tests/test_processing.py | 22 + src/gt4sd/frameworks/granular/__init__.py | 1 + .../granular/arg_parser/__init__.py | 1 + .../frameworks/granular/arg_parser/parser.py | 96 ++ .../frameworks/granular/arg_parser/utils.py | 51 + .../granular/dataloader/__init__.py | 1 + .../granular/dataloader/data_module.py | 200 +++ .../frameworks/granular/dataloader/dataset.py | 522 ++++++++ .../frameworks/granular/dataloader/sampler.py | 69 + src/gt4sd/frameworks/granular/ml/__init__.py | 1 + .../frameworks/granular/ml/models/__init__.py | 29 + .../granular/ml/models/activation.py | 12 + .../granular/ml/models/base_model.py | 182 +++ .../frameworks/granular/ml/models/loss.py | 204 +++ .../ml/models/mlp_auto_encoder/__init__.py | 3 + .../ml/models/mlp_auto_encoder/core.py | 223 ++++ .../ml/models/mlp_predictor/__init__.py | 3 + .../granular/ml/models/mlp_predictor/core.py | 180 +++ .../granular/ml/models/model_builder.py | 128 ++ .../frameworks/granular/ml/models/module.py | 1115 +++++++++++++++++ .../ml/models/no_encoding/__init__.py | 3 + .../granular/ml/models/no_encoding/core.py | 180 +++ .../frameworks/granular/ml/models/utils.py | 37 + .../granular/ml/models/vae_mlp/__init__.py | 3 + .../granular/ml/models/vae_mlp/core.py | 285 +++++ .../granular/ml/models/vae_rnn/__init__.py | 3 + .../granular/ml/models/vae_rnn/core.py | 339 +++++ .../granular/ml/models/vae_trans/__init__.py | 3 + .../granular/ml/models/vae_trans/core.py | 405 ++++++ src/gt4sd/frameworks/granular/ml/module.py | 262 ++++ .../frameworks/granular/tests/__init__.py | 0 .../granular/tests/test_tokenizer.py | 161 +++ 
.../frameworks/granular/tokenizer/__init__.py | 9 + .../granular/tokenizer/tokenizer.py | 509 ++++++++ src/gt4sd/frameworks/torch/__init__.py | 47 + src/gt4sd/frameworks/torch/vae.py | 17 + src/gt4sd/py.typed | 0 src/gt4sd/s3.py | 165 +++ src/gt4sd/tests/__init__.py | 0 src/gt4sd/tests/test_configuration.py | 35 + src/gt4sd/tests/test_exceptions.py | 37 + src/gt4sd/tests/test_s3.py | 94 ++ src/gt4sd/tests/utils.py | 24 + src/gt4sd/training_pipelines/__init__.py | 95 ++ src/gt4sd/training_pipelines/core.py | 18 + .../mock_training_pipeline.json | 17 + .../training_pipelines/paccmann/__init__.py | 0 src/gt4sd/training_pipelines/paccmann/core.py | 105 ++ .../paccmann/vae/__init__.py | 0 .../training_pipelines/paccmann/vae/core.py | 309 +++++ .../pytorch_lightning/__init__.py | 0 .../pytorch_lightning/core.py | 178 +++ .../pytorch_lightning/granular/__init__.py | 0 .../pytorch_lightning/granular/core.py | 153 +++ .../language_modeling/__init__.py | 0 .../language_modeling/core.py | 248 ++++ .../language_modeling/lm_datasets.py | 344 +++++ .../language_modeling/models.py | 271 ++++ .../terminator_training.json | 170 +++ .../training_pipelines/tests/__init__.py | 0 .../training_pipelines/tests/lm_example.jsonl | 8 + .../training_pipelines/tests/molecules.smi | 64 + .../tests/test_argument_parser.py | 188 +++ .../tests/test_training_language_modeling.py | 214 ++++ .../tests/test_training_paccmann_vae.py | 70 ++ .../tests/test_training_pipelines.py | 42 + 191 files changed, 22435 insertions(+), 2 deletions(-) create mode 100644 CITATION.cff create mode 100644 CONTRIBUTING.md create mode 100644 conda.yml create mode 100644 dev_requirements.txt create mode 100644 docs/Makefile create mode 100644 docs/_static/gt4sd_logo.png create mode 100644 docs/_templates/module.rst create mode 100644 docs/_templates/package.rst create mode 100644 docs/app.py create mode 100644 docs/conf.py create mode 100644 docs/index.md create mode 100644 docs/make.bat create mode 100644 docs/source/gt4sd_algorithm_addition_md.md create mode 100644 docs/source/gt4sd_inference_usage_md.md create mode 100644 extras_requirements.txt create mode 100644 pyproject.toml create mode 100644 requirements.txt create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 src/gt4sd/__init__.py create mode 100644 src/gt4sd/algorithms/__init__.py create mode 100644 src/gt4sd/algorithms/conditional_generation/__init__.py create mode 100644 src/gt4sd/algorithms/conditional_generation/guacamol/__init__.py create mode 100644 src/gt4sd/algorithms/conditional_generation/guacamol/core.py create mode 100644 src/gt4sd/algorithms/conditional_generation/guacamol/implementation/__init__.py create mode 100644 src/gt4sd/algorithms/conditional_generation/guacamol/implementation/graph_ga.py create mode 100644 src/gt4sd/algorithms/conditional_generation/guacamol/implementation/graph_mcts.py create mode 100644 src/gt4sd/algorithms/conditional_generation/guacamol/implementation/moses_aae.py create mode 100644 src/gt4sd/algorithms/conditional_generation/guacamol/implementation/moses_organ.py create mode 100644 src/gt4sd/algorithms/conditional_generation/guacamol/implementation/moses_vae.py create mode 100644 src/gt4sd/algorithms/conditional_generation/guacamol/implementation/smiles_ga.py create mode 100644 src/gt4sd/algorithms/conditional_generation/guacamol/implementation/smiles_lstm_hc.py create mode 100644 src/gt4sd/algorithms/conditional_generation/guacamol/implementation/smiles_lstm_ppo.py create mode 100644 
src/gt4sd/algorithms/conditional_generation/key_bert/__init__.py create mode 100644 src/gt4sd/algorithms/conditional_generation/key_bert/core.py create mode 100644 src/gt4sd/algorithms/conditional_generation/key_bert/implementation.py create mode 100644 src/gt4sd/algorithms/conditional_generation/paccmann_rl/__init__.py create mode 100644 src/gt4sd/algorithms/conditional_generation/paccmann_rl/core.py create mode 100644 src/gt4sd/algorithms/conditional_generation/paccmann_rl/implementation.py create mode 100644 src/gt4sd/algorithms/conditional_generation/regression_transformer/__init__.py create mode 100644 src/gt4sd/algorithms/conditional_generation/regression_transformer/core.py create mode 100644 src/gt4sd/algorithms/conditional_generation/regression_transformer/implementation.py create mode 100644 src/gt4sd/algorithms/conditional_generation/reinvent/__init__.py create mode 100644 src/gt4sd/algorithms/conditional_generation/reinvent/core.py create mode 100644 src/gt4sd/algorithms/conditional_generation/reinvent/implementation.py create mode 100644 src/gt4sd/algorithms/conditional_generation/reinvent/reinvent_core/LICENSE create mode 100644 src/gt4sd/algorithms/conditional_generation/reinvent/reinvent_core/README.md create mode 100644 src/gt4sd/algorithms/conditional_generation/reinvent/reinvent_core/__init__.py create mode 100644 src/gt4sd/algorithms/conditional_generation/reinvent/reinvent_core/core.py create mode 100644 src/gt4sd/algorithms/conditional_generation/template/__init__.py create mode 100644 src/gt4sd/algorithms/conditional_generation/template/core.py create mode 100644 src/gt4sd/algorithms/conditional_generation/template/implementation.py create mode 100644 src/gt4sd/algorithms/conditional_generation/tests/__init__.py create mode 100644 src/gt4sd/algorithms/conditional_generation/tests/test_guacamol.py create mode 100644 src/gt4sd/algorithms/conditional_generation/tests/test_key_bert.py create mode 100644 src/gt4sd/algorithms/conditional_generation/tests/test_moses.py create mode 100644 src/gt4sd/algorithms/conditional_generation/tests/test_paccmann_rl.py create mode 100644 src/gt4sd/algorithms/conditional_generation/tests/test_regression_transformer.py create mode 100644 src/gt4sd/algorithms/conditional_generation/tests/test_reinvent.py create mode 100644 src/gt4sd/algorithms/controlled_sampling/__init__.py create mode 100644 src/gt4sd/algorithms/controlled_sampling/advanced_manufacturing/__init__.py create mode 100644 src/gt4sd/algorithms/controlled_sampling/advanced_manufacturing/core.py create mode 100644 src/gt4sd/algorithms/controlled_sampling/advanced_manufacturing/implementation/__init__.py create mode 100644 src/gt4sd/algorithms/controlled_sampling/advanced_manufacturing/implementation/core.py create mode 100644 src/gt4sd/algorithms/controlled_sampling/advanced_manufacturing/implementation/nccr/__init__.py create mode 100644 src/gt4sd/algorithms/controlled_sampling/advanced_manufacturing/implementation/nccr/core.py create mode 100644 src/gt4sd/algorithms/controlled_sampling/class_controlled_sampling/__init__.py create mode 100644 src/gt4sd/algorithms/controlled_sampling/class_controlled_sampling/core.py create mode 100644 src/gt4sd/algorithms/controlled_sampling/class_controlled_sampling/implementation.py create mode 100644 src/gt4sd/algorithms/controlled_sampling/paccmann_gp/__init__.py create mode 100644 src/gt4sd/algorithms/controlled_sampling/paccmann_gp/core.py create mode 100644 src/gt4sd/algorithms/controlled_sampling/paccmann_gp/implementation.py create 
mode 100644 src/gt4sd/algorithms/controlled_sampling/tests/__init__.py create mode 100644 src/gt4sd/algorithms/controlled_sampling/tests/test_advanced_manufacturing.py create mode 100644 src/gt4sd/algorithms/controlled_sampling/tests/test_class_controlled_sampling.py create mode 100644 src/gt4sd/algorithms/controlled_sampling/tests/test_paccmann_gp.py create mode 100644 src/gt4sd/algorithms/core.py create mode 100644 src/gt4sd/algorithms/generation/__init__.py create mode 100644 src/gt4sd/algorithms/generation/hugging_face/__init__.py create mode 100644 src/gt4sd/algorithms/generation/hugging_face/core.py create mode 100644 src/gt4sd/algorithms/generation/hugging_face/implementation.py create mode 100644 src/gt4sd/algorithms/generation/molgx/__init__.py create mode 100644 src/gt4sd/algorithms/generation/molgx/core.py create mode 100644 src/gt4sd/algorithms/generation/molgx/implementation.py create mode 100644 src/gt4sd/algorithms/generation/polymer_blocks/__init__.py create mode 100644 src/gt4sd/algorithms/generation/polymer_blocks/core.py create mode 100644 src/gt4sd/algorithms/generation/polymer_blocks/implementation.py create mode 100644 src/gt4sd/algorithms/generation/tests/__init__.py create mode 100644 src/gt4sd/algorithms/generation/tests/test_hugging_face.py create mode 100644 src/gt4sd/algorithms/generation/tests/test_molgx.py create mode 100644 src/gt4sd/algorithms/generation/tests/test_polymer_blocks.py create mode 100644 src/gt4sd/algorithms/prediction/__init__.py create mode 100644 src/gt4sd/algorithms/prediction/core.py create mode 100644 src/gt4sd/algorithms/prediction/paccmann.py create mode 100644 src/gt4sd/algorithms/prediction/tests/__init__.py create mode 100644 src/gt4sd/algorithms/prediction/tests/test_topics_zero_shot.py create mode 100644 src/gt4sd/algorithms/prediction/topics_zero_shot/__init__.py create mode 100644 src/gt4sd/algorithms/prediction/topics_zero_shot/core.py create mode 100644 src/gt4sd/algorithms/prediction/topics_zero_shot/implementation.py create mode 100644 src/gt4sd/algorithms/registry.py create mode 100644 src/gt4sd/algorithms/tests/__init__.py create mode 100644 src/gt4sd/algorithms/tests/test_config.py create mode 100644 src/gt4sd/algorithms/tests/test_registry.py create mode 100644 src/gt4sd/cli/__init__.py create mode 100644 src/gt4sd/cli/argument_parser.py create mode 100755 src/gt4sd/cli/hf_to_st_converter.py create mode 100644 src/gt4sd/cli/load_arguments_from_dataclass.py create mode 100755 src/gt4sd/cli/pl_to_hf_converter.py create mode 100755 src/gt4sd/cli/trainer.py create mode 100644 src/gt4sd/configuration.py create mode 100644 src/gt4sd/conftest.py create mode 100644 src/gt4sd/domains/__init__.py create mode 100644 src/gt4sd/domains/core.py create mode 100644 src/gt4sd/domains/materials/__init__.py create mode 100644 src/gt4sd/domains/materials/protein_encoding.py create mode 100644 src/gt4sd/domains/materials/scorer.py create mode 100644 src/gt4sd/exceptions.py create mode 100644 src/gt4sd/extras/__init__.py create mode 100644 src/gt4sd/frameworks/__init__.py create mode 100644 src/gt4sd/frameworks/enzeptional/__init__.py create mode 100644 src/gt4sd/frameworks/enzeptional/optimization.py create mode 100644 src/gt4sd/frameworks/enzeptional/processing.py create mode 100644 src/gt4sd/frameworks/enzeptional/tests/__init__.py create mode 100644 src/gt4sd/frameworks/enzeptional/tests/test_processing.py create mode 100644 src/gt4sd/frameworks/granular/__init__.py create mode 100644 src/gt4sd/frameworks/granular/arg_parser/__init__.py 
create mode 100644 src/gt4sd/frameworks/granular/arg_parser/parser.py create mode 100644 src/gt4sd/frameworks/granular/arg_parser/utils.py create mode 100644 src/gt4sd/frameworks/granular/dataloader/__init__.py create mode 100644 src/gt4sd/frameworks/granular/dataloader/data_module.py create mode 100644 src/gt4sd/frameworks/granular/dataloader/dataset.py create mode 100644 src/gt4sd/frameworks/granular/dataloader/sampler.py create mode 100644 src/gt4sd/frameworks/granular/ml/__init__.py create mode 100644 src/gt4sd/frameworks/granular/ml/models/__init__.py create mode 100644 src/gt4sd/frameworks/granular/ml/models/activation.py create mode 100644 src/gt4sd/frameworks/granular/ml/models/base_model.py create mode 100644 src/gt4sd/frameworks/granular/ml/models/loss.py create mode 100644 src/gt4sd/frameworks/granular/ml/models/mlp_auto_encoder/__init__.py create mode 100644 src/gt4sd/frameworks/granular/ml/models/mlp_auto_encoder/core.py create mode 100644 src/gt4sd/frameworks/granular/ml/models/mlp_predictor/__init__.py create mode 100644 src/gt4sd/frameworks/granular/ml/models/mlp_predictor/core.py create mode 100644 src/gt4sd/frameworks/granular/ml/models/model_builder.py create mode 100644 src/gt4sd/frameworks/granular/ml/models/module.py create mode 100644 src/gt4sd/frameworks/granular/ml/models/no_encoding/__init__.py create mode 100644 src/gt4sd/frameworks/granular/ml/models/no_encoding/core.py create mode 100644 src/gt4sd/frameworks/granular/ml/models/utils.py create mode 100644 src/gt4sd/frameworks/granular/ml/models/vae_mlp/__init__.py create mode 100644 src/gt4sd/frameworks/granular/ml/models/vae_mlp/core.py create mode 100644 src/gt4sd/frameworks/granular/ml/models/vae_rnn/__init__.py create mode 100644 src/gt4sd/frameworks/granular/ml/models/vae_rnn/core.py create mode 100644 src/gt4sd/frameworks/granular/ml/models/vae_trans/__init__.py create mode 100644 src/gt4sd/frameworks/granular/ml/models/vae_trans/core.py create mode 100644 src/gt4sd/frameworks/granular/ml/module.py create mode 100644 src/gt4sd/frameworks/granular/tests/__init__.py create mode 100644 src/gt4sd/frameworks/granular/tests/test_tokenizer.py create mode 100644 src/gt4sd/frameworks/granular/tokenizer/__init__.py create mode 100644 src/gt4sd/frameworks/granular/tokenizer/tokenizer.py create mode 100644 src/gt4sd/frameworks/torch/__init__.py create mode 100644 src/gt4sd/frameworks/torch/vae.py create mode 100644 src/gt4sd/py.typed create mode 100644 src/gt4sd/s3.py create mode 100644 src/gt4sd/tests/__init__.py create mode 100644 src/gt4sd/tests/test_configuration.py create mode 100644 src/gt4sd/tests/test_exceptions.py create mode 100644 src/gt4sd/tests/test_s3.py create mode 100644 src/gt4sd/tests/utils.py create mode 100644 src/gt4sd/training_pipelines/__init__.py create mode 100644 src/gt4sd/training_pipelines/core.py create mode 100644 src/gt4sd/training_pipelines/mock_training_pipeline.json create mode 100644 src/gt4sd/training_pipelines/paccmann/__init__.py create mode 100644 src/gt4sd/training_pipelines/paccmann/core.py create mode 100644 src/gt4sd/training_pipelines/paccmann/vae/__init__.py create mode 100644 src/gt4sd/training_pipelines/paccmann/vae/core.py create mode 100644 src/gt4sd/training_pipelines/pytorch_lightning/__init__.py create mode 100644 src/gt4sd/training_pipelines/pytorch_lightning/core.py create mode 100644 src/gt4sd/training_pipelines/pytorch_lightning/granular/__init__.py create mode 100644 src/gt4sd/training_pipelines/pytorch_lightning/granular/core.py create mode 100644 
src/gt4sd/training_pipelines/pytorch_lightning/language_modeling/__init__.py create mode 100644 src/gt4sd/training_pipelines/pytorch_lightning/language_modeling/core.py create mode 100644 src/gt4sd/training_pipelines/pytorch_lightning/language_modeling/lm_datasets.py create mode 100644 src/gt4sd/training_pipelines/pytorch_lightning/language_modeling/models.py create mode 100644 src/gt4sd/training_pipelines/terminator_training.json create mode 100644 src/gt4sd/training_pipelines/tests/__init__.py create mode 100644 src/gt4sd/training_pipelines/tests/lm_example.jsonl create mode 100644 src/gt4sd/training_pipelines/tests/molecules.smi create mode 100644 src/gt4sd/training_pipelines/tests/test_argument_parser.py create mode 100644 src/gt4sd/training_pipelines/tests/test_training_language_modeling.py create mode 100644 src/gt4sd/training_pipelines/tests/test_training_paccmann_vae.py create mode 100644 src/gt4sd/training_pipelines/tests/test_training_pipelines.py

diff --git a/.gitignore b/.gitignore
index b6e47617d..8db12b174 100644
--- a/.gitignore
+++ b/.gitignore
@@ -70,6 +70,7 @@ instance/

 # Sphinx documentation
 docs/_build/
+docs/api/*

 # PyBuilder
 target/
@@ -127,3 +128,14 @@ dmypy.json

 # Pyre type checker
 .pyre/
+
+# Visual Studio Code settings
+.vscode/
+
+# PyCharm settings
+.idea/
+
+# custom
+logs
+test
+.DS_Store
diff --git a/CITATION.cff b/CITATION.cff
new file mode 100644
index 000000000..e35888f7b
--- /dev/null
+++ b/CITATION.cff
@@ -0,0 +1,10 @@
+cff-version: 1.2.0
+message: "If you use GT4SD, please consider citing as below."
+authors:
+  - family-names: Team
+    given-names: GT4SD
+title: "GT4SD (Generative Toolkit for Scientific Discovery)"
+version: 0.22.0
+url: "https://github.com/GT4SD/gt4sd-core"
+# doi: TBD
+date-released: 2022-02-11
\ No newline at end of file
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 000000000..c9a4b231c
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,125 @@
+# Contributing
+
+## Contributing to GT4SD codebase
+
+If you would like to contribute to the package, we recommend the following development setup.
+
+1. Create a copy of the [repository](https://github.com/GT4SD/gt4sd-core) via the ‘Fork’ button.
+
+2. Clone the gt4sd-core repository:
+
+   ```sh
+   git clone git@github.com:${GH_ACCOUNT_OR_ORG}/gt4sd-core.git
+   ```
+
+3. Create a dedicated branch:
+
+   ```sh
+   cd gt4sd-core
+   git checkout -b a-super-nice-feature-we-all-need
+   ```
+
+4. Create and activate a dedicated conda environment:
+
+   ```sh
+   conda env create -f conda.yml
+   conda activate gt4sd
+   ```
+
+5. Install `gt4sd` in editable mode:
+
+   ```sh
+   pip install -e .
+   ```
+
+6. Implement your changes and once you are ready run the tests:
+
+   ```sh
+   python -m pytest -sv
+   ```
+
+   And the style checks:
+
+   ```sh
+   # blacking and sorting imports
+   python -m black src/gt4sd
+   python -m isort src/gt4sd
+   # checking flake8 and mypy
+   python -m flake8 --disable-noqa --per-file-ignores="__init__.py:F401" src/gt4sd
+   python -m mypy src/gt4sd
+   ```
+
+7. Once the tests and checks pass and, most importantly, you are happy with the implemented feature, commit your changes.
+
+   ```sh
+   # add the changes
+   git add .
+   # commit them
+   git commit -s -m "feat: implementing super nice feature." -m "A feature we all need."
+   # check upstream changes
+   git fetch upstream
+   git rebase upstream/main
+   # push changes to your fork
+   git push -u origin a-super-nice-feature-we-all-need
+   ```
+
+8. Open a PR via the "Pull request" button; the maintainers will be happy to review it.
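+
+When your feature adds behavior, consider shipping a test module alongside it so that the test run in step 6 exercises it. A minimal `pytest`-style sketch follows; the module path and names are illustrative only, not existing files:
+
+```python
+# e.g. src/gt4sd/algorithms/tests/test_a_super_nice_feature.py (illustrative path)
+import pytest
+
+
+def test_a_super_nice_feature() -> None:
+    """Replace with real assertions on the new behavior."""
+    assert 21 * 2 == 42
+
+
+def test_a_super_nice_feature_error_path() -> None:
+    """Error paths deserve coverage too."""
+    with pytest.raises(ZeroDivisionError):
+        _ = 1 / 0
+```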
+
+## Contributing to GT4SD documentation
+
+We recommend the "Python Docstring Generator" extension in VSCode.
+
+However, the types should not be duplicated in the docstrings.
+The Sphinx documentation will pick them up from [type annotations](https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html#type-annotations).
+Unfortunately, a custom template is required to avoid adding any types at all.
+
+Its settings are:
+
+```json
+    "autoDocstring.docstringFormat": "google",
+    "autoDocstring.startOnNewLine": false,
+    "autoDocstring.customTemplatePath": "/absolute_path_to/.google_pep484.mustache"
+```
+
+where the last line points to the custom template file (e.g. in your user home)
+with the following content (just the placeholders for types are removed):
+
+```tpl
+{{! Google Docstring Template }}
+{{summaryPlaceholder}}
+
+{{extendedSummaryPlaceholder}}
+
+{{#parametersExist}}
+Args:
+{{#args}}
+    {{var}}: {{descriptionPlaceholder}}
+{{/args}}
+{{#kwargs}}
+    {{var}}: {{descriptionPlaceholder}}. Defaults to {{&default}}.
+{{/kwargs}}
+{{/parametersExist}}
+
+{{#exceptionsExist}}
+Raises:
+{{#exceptions}}
+    {{type}}: {{descriptionPlaceholder}}
+{{/exceptions}}
+{{/exceptionsExist}}
+
+{{#returnsExist}}
+Returns:
+{{#returns}}
+    {{descriptionPlaceholder}}
+{{/returns}}
+{{/returnsExist}}
+
+{{#yieldsExist}}
+Yields:
+{{#yields}}
+    {{descriptionPlaceholder}}
+{{/yields}}
+{{/yieldsExist}}
+```
diff --git a/README.md b/README.md
index 2f87b636d..b5afb2f1f 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,164 @@
-# gt4sd-core
-GT4SD (Generative Toolkit for Scientific Discovery) an open-source platform to accelerate hypothesis generation in the scientific discovery process.
+# GT4SD (Generative Toolkit for Scientific Discovery)
+
+[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
+[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
+[![Contributions](https://img.shields.io/badge/contributions-welcome-blue)]()
+
+logo
+
+The GT4SD (Generative Toolkit for Scientific Discovery) is an open-source platform to accelerate hypothesis generation in the scientific discovery process. It provides a library for making state-of-the-art generative AI models easier to use.
+
+## Installation
+
+### pip
+
+You can install `gt4sd` directly from GitHub:
+
+```sh
+pip install git+https://github.com/GT4SD/gt4sd-core
+```
+
+### Development setup & installation
+
+If you would like to contribute to the package, we recommend the following development setup.
+Clone the gt4sd-core repository:
+
+```sh
+git clone git@github.com:GT4SD/gt4sd-core.git
+cd gt4sd-core
+conda env create -f conda.yml
+conda activate gt4sd
+pip install -e .
+```
+
+Learn more in [CONTRIBUTING.md](./CONTRIBUTING.md).
+
+## Supported packages
+
+Beyond implementing various generative modeling inference and training pipelines, GT4SD is designed to provide a high-level API that implements a harmonized interface for several existing packages:
+
+- [GuacaMol](https://github.com/BenevolentAI/guacamol): inference pipelines for the baseline models.
+- [MOSES](https://github.com/molecularsets/moses): inference pipelines for the baseline models.
+- [TAPE](https://github.com/songlab-cal/tape): encoder modules compatible with the protein language models.
+- [PaccMann](https://github.com/PaccMann/): inference pipelines for all algorithms of the PaccMann family as well as training pipelines for the generative VAEs.
+- [transformers](https://huggingface.co/transformers): training and inference pipelines for generative models from the [HuggingFace Models](https://huggingface.co/models).
+
+## Using GT4SD
+
+### Running inference pipelines
+
+Running an algorithm is as easy as typing:
+
+```python
+from gt4sd.algorithms.conditional_generation.paccmann_rl.core import (
+    PaccMannRLProteinBasedGenerator, PaccMannRL
+)
+target = 'MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTT'
+# algorithm configuration with default parameters
+configuration = PaccMannRLProteinBasedGenerator()
+# instantiate the algorithm for sampling
+algorithm = PaccMannRL(configuration=configuration, target=target)
+items = list(algorithm.sample(10))
+print(items)
+```
+
+Or you can use the `ApplicationsRegistry` to run an algorithm instance using a
+serialized representation of the algorithm:
+
+```python
+from gt4sd.algorithms.registry import ApplicationsRegistry
+target = 'MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTT'
+algorithm = ApplicationsRegistry.get_application_instance(
+    target=target,
+    algorithm_type='conditional_generation',
+    domain='materials',
+    algorithm_name='PaccMannRL',
+    algorithm_application='PaccMannRLProteinBasedGenerator',
+    generated_length=32,
+    # include additional configuration parameters as **kwargs
+)
+items = list(algorithm.sample(10))
+print(items)
+```
+
+### Running training pipelines via the CLI command
+
+GT4SD provides a trainer client based on the `gt4sd-trainer` CLI command. The trainer currently supports training pipelines for language modeling (`language-modeling-trainer`), PaccMann (`paccmann-vae-trainer`) and Granular (`granular-trainer`, multimodal compositional autoencoders).
+
+```console
+$ gt4sd-trainer --help
+usage: gt4sd-trainer [-h] --training_pipeline_name TRAINING_PIPELINE_NAME
+                     [--configuration_file CONFIGURATION_FILE]
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --training_pipeline_name TRAINING_PIPELINE_NAME
+                        Training type of the converted model, supported types:
+                        granular-trainer, language-modeling-trainer, paccmann-
+                        vae-trainer. (default: None)
+  --configuration_file CONFIGURATION_FILE
+                        Configuration file for the training. It can be used
+                        to completely bypass pipeline-specific arguments.
+                        (default: None)
+```
+
+To launch a training, you have two options.
+
+You can either specify the training pipeline and the path of a configuration file that contains the needed training parameters (a sketch of such a file is given at the end of this README):
+
+```sh
+gt4sd-trainer --training_pipeline_name ${TRAINING_PIPELINE_NAME} --configuration_file ${CONFIGURATION_FILE}
+```
+
+Or you can directly provide the needed parameters as arguments:
+
+```sh
+gt4sd-trainer --training_pipeline_name language-modeling-trainer --type mlm --model_name_or_path mlm --training_file /path/to/train_file.jsonl --validation_file /path/to/valid_file.jsonl
+```
+
+To get more info on the arguments of a specific training pipeline, simply type:
+
+```sh
+gt4sd-trainer --training_pipeline_name ${TRAINING_PIPELINE_NAME} --help
+```
+
+## References
+
+If you use `gt4sd` in your projects, please consider citing the following:
+
+```bib
+@misc{GT4SD,
+    author = {GT4SD Team},
+    title = {GT4SD (Generative Toolkit for Scientific Discovery)},
+    year = {2022},
+    publisher = {GitHub},
+    journal = {GitHub repository},
+    howpublished = {\url{https://github.com/GT4SD/gt4sd-core}},
+    commit = {main}
+}
+```
+
+## License
+
+The `gt4sd` codebase is under the MIT license.
+For individual model usage, please refer to the model licenses found in the original packages.
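+
+As a complement to the CLI examples above, the snippet below sketches how a configuration file for the `--configuration_file` option could be generated. The keys simply mirror the CLI arguments shown for `language-modeling-trainer`; the exact schema is defined by the selected training pipeline, so treat these keys as an assumption rather than a guaranteed format:
+
+```python
+import json
+
+# Mirror the language-modeling-trainer CLI arguments shown above.
+# NOTE: the accepted keys are pipeline-specific; this schema is assumed.
+config = {
+    "type": "mlm",
+    "model_name_or_path": "mlm",
+    "training_file": "/path/to/train_file.jsonl",
+    "validation_file": "/path/to/valid_file.jsonl",
+}
+
+with open("lm_training_config.json", "w") as fp:
+    json.dump(config, fp, indent=2)
+```
+
+The resulting file could then be passed as `gt4sd-trainer --training_pipeline_name language-modeling-trainer --configuration_file lm_training_config.json`.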
diff --git a/conda.yml b/conda.yml
new file mode 100644
index 000000000..e1fea253d
--- /dev/null
+++ b/conda.yml
@@ -0,0 +1,8 @@
+name: gt4sd
+dependencies:
+  - python>=3.7,<3.8
+  - pip>=19.1,<20.3
+  - pip:
+    - -r requirements.txt
+    # development
+    - -r dev_requirements.txt
diff --git a/dev_requirements.txt b/dev_requirements.txt
new file mode 100644
index 000000000..58e669893
--- /dev/null
+++ b/dev_requirements.txt
@@ -0,0 +1,14 @@
+flake8==3.8.4
+mypy==0.800
+pytest==6.1.1
+pytest-cov==2.10.1
+black==20.8b1
+isort==5.7.0
+sphinx==3.4.3
+sphinx-autodoc-typehints==1.11.1
+better-apidoc==0.3.1
+sphinx_rtd_theme==0.5.1
+myst-parser==0.13.3
+flask==1.1.2
+flask_login==0.5.0
+docutils==0.17.1
\ No newline at end of file
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 000000000..cda001c89
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,24 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS    ?=
+SPHINXBUILD   ?= sphinx-build
+SOURCEDIR     = .
+BUILDDIR      = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+clean:
+	@-rm -rf $(BUILDDIR)/*
+	@-rm -rf api/*.rst
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/_static/gt4sd_logo.png b/docs/_static/gt4sd_logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..086822e82e738e89db1b0b4bd78e039eba6e470d
GIT binary patch
literal 145035
[base85-encoded binary payload for docs/_static/gt4sd_logo.png omitted; the patch is truncated at this point]
zdL8-WuYcFic5F1^VSQNC!Ld^^w@WzFEZ_;Me>!C(?@E=52wrNCuQ@SFW^xhiByC2~ z+EQSpviU5-Qes;F@p)6+lg}mj)hzXLn^@a5W~MP`K%@6y%c>tQ!qb*Jt9@aVShDeq z7L<0)VjWK+#;+Y}{|P68%CPd5UMMMdrykCilww@?45Iz%Kj!YarKEU}ujk+Xwu z3dH>%*xKlZPW$v9L`R4@ZTIgVW04zTTA<@}CSZh^Buc8x;on3Hnkd>iubihqzrFKo zK29A}S6%1-DDRV4$*dZg_Vj_){lQ=B+pCDYs{)3|ZXPk41b?ZI$caT3W&7~=IT3KEzGCj-cyL!*t z-q4$(@dp5Ms^jfyzpz4S;@AcYgqZ3)tB!y`!td|@1%4^~UYotX^=^_-h=W!tKp5GY z&p!OT`O{kszv?4FSMIqViCXz(;j$_u0(Ei2(fw<))T=Oc)z1Y5v+HaUvunV%Z^o9)`@%Av#D?gq$QHCu*?)p3 zDIgHbjTF7P_`_f#Q4;c}{NYin%AUe+DMAUaoN!3bu+wQ#d-hk$^Wt)M%f>6~uUq+4 z#qMnAc?5XhdqQ?fM~_ICV$0&EGwB3nR=F4T`oW;*TYKcZEE^YU-nWN;@yx?k=Ym}Q z4sR*1zS?iQ-kf@MyC-7%av2_>mfT+=-Jd@EDfIeXiEz|>IEA3{aB1c^MzsLP;RReb zOS$2KK;c)Ar=a2?L2$S>Oua5!-TmpUSH|_@TBd$xqwT?iNJ&2*Wa_!nQS65s1Y`QG z=T7%+pELKg^Z8qbH*a$-+2_ukD9Un&E38ut_Tmk_h&i|iL0D?j`T2Y$6W-E4i*KkM z9EiKUGUC1E_RFb+%q*j+BbFbA<8h&S^adZX*FSv>|Do%dS=m!ruYu#EbvWu$Z%@UR z)}0G}ePVb`_MG}~%K5L;oX>v!M1_jHYAYCE_g%h!lcM_a_hfK|3!1qQ;vjQF;K!kC z@2Q7=Z#+|TEhrTp4g_MUdEYFJ1h1-n8?$U=D0AzD)tkI{F%J!|vAhL!)u&B;C<>p> zGXw{?@_32)zwM#spKA4rZWQy*vA&yQho(?@KncQ=s@*%5kQy<8#1y4UkXgwwh;7%O z>B=B~BI7bO#Ku}f9uEAopg!!h>6sdh;Scv-h>EVn^K?O;lEta$@xmfSJ+>4k+!hkg*5!phFNZ@~*dUifqZ zI^^RlMuEGK$u6bXVYc2*lf9xNpdRvP&K#@i^6zGkf4J;Yn~;^0F*{?S|G49}*QI6B zaew*8?Yt@_cY-~#P%*-jgsl(h(Yy2n}R-_O=aP!aPO zz3lR-qU>PF$JnzvM`x^fao+qu!5g`U`d1fLra8`NuTr7HZeOq?yOEWegOC>qIg_4m zWSp7+^W>?@L=YQ-15K65+KUxZwKi@%Ua@b1zDCwMF;};3YQtql_2h_=`dRaB<7zHF zymaa=+z`;P{0TiEt-u@za$Uc1{>1`Sy*VtW8zp+GGN$fr4~W)*Dn-&^?SFl6H;7Zi zCRD|rlycv{d9u%io zlwx(h>D_}(y-kIcTkyF=aRe4g!(I}<;p(X@M`oL##UIOp>a&0pfYk%u(rfh=SNBz! z)jV{yQtSKotxNaNgRZAN@p;^WcweIFtBs^69cH zLdtJiZF8_R3{{!QlD9=l4srV8grvZDPojmk@Pw33$$7bKB`?dpuDa%3`pwm0o~|Zp zJ>84CKGnR*Ns^tYDf>32#^;rB%lGqo{0UF|$I#KiszxC-#kN0dL(Z+Dj15W1g9uic zjYK4hNukvrA%U4nmETRyv}5N|LiU0yoQg$}^{P?5W09$;!!BJA;uS$?uufbaXiS zPJ)X4f5EY)+6E$ZjnMiR(-(i2x7c3LXOAUMdF@#m?CslEt0c$y`NhU8y*@VLRbbV+ zF1NddRLo5(+?&SPKzM zh1gZT-4mXshTFe?74zcyrjl;U>KkguLehI4ZYnt!y=q|oe^?epNKLB8?5+*47tq~w zyF*9yBx?U~=NB;Ma%}*+D0$PN`_|is5+?johlgG4$3B)PK0qN!`Qk-Nw!Xes$2CcGxQ05v zB<%l+>RtG46Xni;(9S)Cc`98=;bW_a7Zb&!q;^C25z2DKVeeI?qt`XRkLI%*8lJ;X zM~8X@h6!apl2v0buPLvk8Oi^Vr^NRak-%y_>7GYMSCNiWQDJV!NDnfB3vMHo89-jR z1Ul0qhHlX>at~(s2kbhn>hA<97FpxT+ zz*T#{4^?p#*nzAG(&mLcgL}p?-y#nkwn(sev#;p9sB0&-3J+q&yl-6$N*shidOgH% z(40tS`>qCaG3rgwZKT32`DVsDlBe4B!fafgsKjWyH9XKE2${0Rsj(aQ(SfV!B;a1{ z+DM>-IR&|tY@tmRmIl>*3G41{JDh(_BtC20ZL;nC{J9|UGJ87Z55iK7L{k54H`10u z^iDL6%|&Xi03EVV3)TgY7&)T#q?rer>wZUS;{?YZ+(M}~zz6?^yGi-yNc%UO8zfHClIO*#PPFh6R|DK`RsG0wvXTGEev?7vArFoJsS?$i3$|I0TA}kDsB9vBSdywdx#b&2 zpFv!gG8D9$w^u&GvfFsE9bD7`XC6nN^raw=OBq}tB20vrw3wZoS>kb!MsN1Ujg_c* zAmLU*foN6#baL)UwiA_u$&o${*bYhgDeJ=80j54 zbcpok5aNw9eUuBok({v6;j%2k!GcWPj`lVcE&V^%Z7)#k(cS<1i%GhihekM);mLBM%FM9vSp*5qc4c&U$`X!^)&g=C z&x9u5y(cuSuyYY0l_I5yc;o9d=Z?>sMP|>UD#P$T5qpH5a`}-JdsKvxX<;PYp2+~2 zwDJTToH?;6)g+=vJd|iaE@-F%(~Q27AD0647a}}!9y&VVo`h{rR=3KR3s%B+|DC_a zFe$tzgL1Clxv?=j7#j%A0byI=q3pP!sMszXC}T>E3n^jsufFk@^mWSya% znFL{p;01_?{zoe8nTZ6f6{l!+yXz6+SV~~I7y&^I#XL1@Nr>>XRE-C=cc-%R{S$L& zz!f4n&FDdN*EB-~VOeIjD=XqU^VxX-Vr$*in$g2J2y(OLONY4}WTrB(c2Q*?cwzJo z#0LCGQHg5g7AO+uY-i*Dl4pWSh{V6_x(ydM;i5@v&+Rlb6Gzay2y?KRVKS`fEi`c% z{GHk5Yx@)zleqRAX(GH(-eQaJlT@h%-WM&8DeC9ZFcX7Pno*|a2bs@I;A62j(PD+O zVlC389zxRMW%r?HEnIVeEqobObwOVCGpR_S#u7$u>F}Nysh6>J)e){!Be?-&-S-&X zOlll4f4x9&c|K~r=A-CWhyZseyBjMFiI&_&B2R#bPT;Du$VF2VT-Y6|ObE+%{XxR@ zsDHdnE(`_SBtZppB=}FXCZi1t+Anr|<}r#sBWPJM@wctph|7Y1z?WeZZm?NQxx@aA3jNZ4P@t?8dnK6SmH zD%O&2TTPG`0;1bRRk&ekv;h%69vEXrsDxyFY?5``E(9T?`$e)< zppRjJzm9hKd;1^_97BGZ(d6bV8JQ(qHQjJaB#(_E9~eV2?(WVGF5 
z0^;S6TI9^WSd5hY;V6=>>|ksOQJR`6iGy$~AZ^Xek^Br+;KtK2{T^v}!td^6L|Fy# zsG&pS?|>i)MUydSV(M;!q4Ml)48Pn+-Q35u;e<&a3No3G#(ezYg!mlmHACcE5~0ah z+Sp*4R53cm#>*AjEJAQ{s__=V8S&mjrzZYga5w1kF0xu|Uv zOuu0%k};3uNi7l|?>v2F?GZgtNOhN=^HYq|>lY#nf50b%-8L0f)yUhom!lEDMIg+{ zSt`rokhCsXxr`xV2QC1U1I76gK1>D&)jmEc#H+99^A^q&^XPm`nh)cQ+=444gWw-d!d!as3Z`ZT(D6_ zi(%J4vptsyzRW1X<+d0nU4Kza4ClxvJ}#1|3WHsM`#;jKK9Kcre7Lp#TeHDm373Y2 zQ~(#Q0A)B?jP#0XA}RKBmTtyi;%3q}kF??Wh0|v*$89 z^y~yeJw?F;kq~Iu{nWi?q*}N|Z^61kOoJqr8#pEu+69;gBtSz}KSS-GZWch|bmho; zK@r^|KJ-kh4qXSCfqbk<3;)&<_f%N@s38EL4n>tYNyig>_JAydU8baPo4Gv5RtzxS z1Wql`#yA-aZ=HBT#$4`6u`vPKkTx}Bgo*%!_oxy%*Nll2fS z*ZBM2`uXF3m|uzsqQ$S zW>P}AHmu7>c#`D|Y_i;Yi?kxi0J+F5vYyR4=5Tll-{-|>dfsKl#oy0X>ts)ov%_vj zkdaS4(Nj(N|H9tr(cxZmp0xMI3rQb`Yna+S>4j8wHaKz@5MP80WXR9Y2m8AA`@dhy z+0{K2?C^jly~79>=!T+D&o6>_H4^r_hiXn8PSbG*wRh=b0xcNVVy z{eS=y3#T&gD@HvLDS?3ATAWth7dLsZIwTl#Ddw1VBmfM10dw5- z;e$`5>sFc#Z3V&*8<`}K2tWEa^U>#!=PA_mT^L{JVfwK1U(6(>?}M|NmJ{sl`}XZ4 znfJ62L`G)AQzp=C3v-zF1{s|Pm z!%;spPL*}}g|q+@@arZinhF2#8w5L4C}*&BB5L;;8%M&=CJ?o3C0aGZ2vS^>588?m z0T7_G`s_~|P63xaVtd%DJVq;yjN>I*cYf1APSOG zFI+{t>^LF4(3I7vY1lqL%fNXxOLMeSfX+WM4o$@)@RE>0V_=sV?`PVcA!Yi46KtXB zm;DM1RfoB^vUoE}gJhKHEzAM7Jwn6YgBZSw1B9r%JE>t#jzINb1<<_3U@g)X=M&|V zrZsSBJ7kFq)FakB!BOo{74|gLxjGmw4_NJ+jg5_ogN>Ok$%>{j{(%Rt|3c?k8PEqo z7wB4fn0q9&{dY0i!W~n0cXtH^g{OR4>%n|hS&|PVK&)Lz-e^D~WPKf&7kO9Sb*oKSHclvl$D&{n<*II11 zaq^T+OfjW#Wml{}#?RZK(g!y7QB-)_Va%=&`{ps&cL(l~Thy3UDAc{TlyXR+l<*kK zXoc`jF~$~B2|zKg!7Z0G6Kwx{+YRt`@mAmW5>>+mk*||H==UN z2ZNr4H)2S;$~fno$QacW*pi^7Sb{DZgS7&Ri+IkUO9BZi&6U<8COf7C=DU36rrJg3 zOOj{B8`5VB4pgv85iokpM7^t!+!z=Bu1lnTl_aU)tR|q+&&(i} zA??NEoGYrX7W4J>O#rJTVild@Ne?vP2B$Glkt3B;WbiOhrlqwtGB#G`GiCZURV+l{ z@~`;BVBQt57myoCI{nO5Ww4aK-YsQtD#A!->=BHKl!MwO{PHRK%NJ~rS=U9TkpK^f zzx|6T(}cIH$neu$#QRM9ZlF?()P3fxGB9U8tLWC-RJz5MaAJMn6vV!ri2}d6#SK$U9?r*y)u$pOQH+#H`*#D$qCGL%M zZ2Zjj;hdqb_kgB|@(Bx!lwS%yNre*-rW72RHNm`YXaA!DPq2#|St3^vi~U@#j-I65 z6Z$;lPtHfZBk_xn6nO$A0kB!)jN+uv!*J31@sW=I%kV+xX1;yIk#2+k#ETRx0!K;k z^0OF1Z4J$YI&XY#J&M5nPUaQ}7=A5df7EIRgY{LS;D}bC?;5A9ZfPB>o6f%x)(WoC zOzK}hB@yG%4qbN0iBaEeSFA04=m@=xHEKjNnNM6lKR?m~?Gw-kcvIj#-edmnUtbG> z$gl4OB6=pm>Y%0rg4hw)v$V20k4V%Jf#s^kNU#34DEc5hfdR%H_V)JapAN|i#r0q~ zgNas4wfuaD$xf6aaZiVrNPHI@`T*5N4jaB_qyc^8X;5bGO(27eS*98pQz)P(P6ln$ zWKZ-WDna_igZ=h~4Wvs~55GwfxaU}(6lvi)9)%*u9BlCd4ni7okim5VtiGpDZ_njd zI4uV1&^epmQLCQ!DsHE5HeHtX#p`y&d9ikprcIRflA?Ym=?gWQ()!Yz8z zN^-V0)m@Np{o7}MG5sk*>z@hP1K~75cZyR6B&v(a1;Q@qs%+E{`e19dnXlJ4I5=!1 z-desPBJ1+_TQY25(giZ^+KIzpSzRckTn~7wyY-$S-tW)kZ08N;>-wZ+#uXLEW014C zeoGdi%4*WlcQPQ?6hUFI8Z*91HSD?{r7rWo*Db5;Z+mBssoF&q6~qOmgos#I=0Jw*q@}bJ zT~>+6O&wSq;b9~)1a=}k@<=G3m9QU^KDoItBY=zwn0n@dm`SJ*$!cgo|8R$P-VC%? 
z5#_etXYD~4Il>p3(AfbI21u9lC2cpz}!F(xC7X_ojy9T}m$A65jE zpM9u6UVP%&PePznF+gFF&=%5GXws{OPdksO)k55|p)D{24oT^ttw|DDfO;$B?5T9P zpY4k%^1He_F>x@NiqUekPNcwvjJ@YS>=YUaqJMMm|8y3v_VNsOXF~HRF4_ztoV1cY zZ}e|O!y_e303d0FM@GtfX4G^tuXIhnLj7y>cC-S~r+f&JTe#Btek3bafPO#HQC+t0 zvE#>mkW31$(1D@s!h@7OK`crTyN{Ne6i9~3 zOgh90dFBn~ywLxA&Yl)?TZU5*C(ujOGzd4xApOjtT&2u8z!(z9`=7CT))p%9o5bpg z4*t*goPplE#*KjrW9; z2#tQYM$y~1qLWo?mSE+%D$IhSYz&{D(JV+t51m4VNO+Dx!kdV25w9lc<3^iAWSX%M z;JfHx4rhfUrpDx=N9Prxv^C9eJ0y^{!j;4sa0THA!m52l=A!Wmk4gAoMem^ny`d%8 zN&;3w>0507QjR2&PTC+XpbQS#piAB0>nIHp9{-=-D%J{Gl(6SjBwHIRC@BFh&15be z8H6Q3hX^K&v+E*{E1icl`B0pJ^j9X$(gi%wU>*^)wG-uO0}It(d^uKM=5Q-C?AG7K zJ1mtA&;7V8V%c=L6?fy#+sQqTJN4JX$6tQ9@+)T-`{GHNz_0#%ed3PjfteMX&n&FS z*<`Y7*@1{XtA31HRru={*8jv79&PL#sZYrI`A@d%t$~+n;rb^UJsUlvUMV(ft_kkI zu;MlB(RSyo7w}Y=x@}%+f{LFCl(gs+&_u~v_}yQmE{?^T!ZtE8@s zf%jWqccz%}ZnCD$;L}_-YYo^uI(J9+`s5O|l!Pwt-@kt;j7b?-bKomdC(U9FKO#yG z>C)XS;_EOz)OZ9Guv(qrf(O`hNdF`jtJgduBSRUn1PtDZzl5(N*Th*5O?xMOeKAe( zn*C!P-kPMF^wt8)YY+~Eu_8EhQ-Clt*8fX67dT$Wpa&>p_ODKXC6)G149l&;vtOXk zKK7a?9exdQ-N3M-^+cOP&ZvnS!wa)HlBB|s#cDcDhPlI`6@i1JTQlmQO)g_<;^j#P z|C&B{FiIwmK$*FDM4=dAyI|sUx+%~~XIL^j{D`vfm*_ZS;}Hfg zJ~a0rbQ*}8^u|s6HyV_a9;2l4AF#>;>w+f(r{d z9a=)*ob;oC6ERm`O%4&YRqBy;PRMm5MW8Z5Bs2J)k7Pe|lgw`tSO7X)gSWc@y-avW zSJKgE484VyXh2RSj!=X^WYr@t5~A>ESSr3;_{u`&m21FU;Ig_8gs8xj^R)obB$QR8 z3FffVK$UYNP#ZhENXF>NB$rQy^)^0a?ukI$tu{mazPO}hr`y(gQo%X|ZP-YNixJX% zTrBrNzq?Qfic3IQqDI~fGTkRw#$HN8a2=PN_YtBLgakf?n2dRX8H%Pe1D?p#T>(Z zkLRoPyd+qQ+?joF4-8xiW1wBQj7A;}CuGMqAgP4hVD&)ezy~Dh$_l_8VMp)Mvk6^3 zULqS}7d}lU{TU(071><7-RL%vEOtG^d7eQkpFovuP*~9x=}?Cq8gR=52|vAs`84UZ zVF3p;K}N(GE?~$)mJn%k>5K0D1NO;BSV*gJgGAMaBOvmBFME9#0WBz`I)G;La$kFR z$Qse{NQjIPMf6b=N05{d3?_SzFi|C&{Un$>3EKKJu|%F0_gW7mTaND%Hz0UK!lduc z<)ilN8&OfxIg9nGuEKCGZB_Y~$V&YO)AhBeH0J%72G`5!i2O;fXl-Yrk z>F6K5MtY(R!Hu|AKvgVTHchBslcFazv@9lA0}lxr?Oqe^aBw<($+jiL_6F5KQf?-_ zz~s@9_WxS9flwcs>p^8EvX?}wNgzwO)MYLW$CUToFJudF7vPEYKFE|sw;3UNaZ>0V zcFrfY&&o(zCWj`OWblkkx|aDgnN(jV$~Qvuz-zmHC@>@>Jb4{0&#M0<4tUamk4TF* zZvsj20D>$+yy=|JEPaaDyF9dL5dBe9LsUSP3-yg!>+kp|n|_wm;T5P4?8h5>prfcB z33Y^!A%d1X>D%ZNd)K03t^oB(Pw>JhI8+hJd+{09uC=DAH2n`aRY^4vv6|$Uc)gSq6(Z!c!vtlM1;dY=FE=cmS zfHP&QO3=$LBYAQKac+?-Bqb;n$ov%w{n{nWsh^WMdx}-@% zz5Pic1fhLICO`PB&0i*@EKxoTIk_XqCcBL03CV=l!RHtE=wk>ZN030`vwis#!;Z{= z2B?hUlahDujEMyF9_>;_j*Sq3A`bH*9#MpJn`^`)Lue2|wG)r0E)dEVwm6Mvz7Npx z#G_HQhMWbCj*dg{841r)NYg5NsiU)VV|1k7b9nYC@`qn}=+hBRV8(7v|B;yM`=W!a~VV0K_kZ7tI!@@6Q)TN}dILMdToP8vO~2 zT^lY=+F>38o66sx%p4jj@P;N|@3T_PlTrYrk&W>5k0U8;m*r-X~xwXjMxd$tI?!QMx6<0a$eB@V26`YHyRj@5Gcgk z!>@O{@yX}pPK2MU;Zi&do(~R@9RW@``*K*}dUy!>Yu3%!10B0dhaUheKhAvfEQpsm z`sYs#$L5R ziNH;13V?+_($a1_$||nG4MsGcPCsk7mv2fd*y@?B}=>~-d>R_4M(|D#liC}9n5)ItaZNpFPw>sYbfy6L$l$blJI~9W_RHQ#)d@c!^~RP0!VYkIiMTQeTH*2`m_ zI^*SKWgnoNQ5=nA83Y`_iBD$n{8Xh>+jgmouRETU&nT%FiuE%wpg1YfF)^t>bu~wE z^TcB6l5Xn6G|HwMxGg#PIMvaRHPXP1pBTR)$IDlh-WetJ&NlpbiizVby=`ybDRt}? zK*(h&#ob(d`&zluVbOm+)e?`*H}{t;(75X09Av7+#eB_G`*wvISy{!{SX3SgYS9hd zr*4zwS^Re2;gy2dyrMkF5k-Nkw~2W@{ARAW>kYTjtckhH0?scR;17^x^bAtls^bvQ=N^Jq%`_T#ipZ}%a3~#uJ_6}sfLE~PR`!l zSi0GFYxP*Wg}-`0m2%9%o|gCel0G zdq^kIa=YWYh05|3)+sA2Du1u;TP-L2wQV@TK+YGoIp=&Y)}6{LjnRq!^N$Z7g195? 
zG5OP;oHmlH4|A7t?=T5HJNJb(AeGt7QiW?EX|IbhLLn5ZhiizYKl7W7)RzL7+Sl6lLj-EHR?Wc2Q6HZ0u# zB1n2-tfo{upuc~GX{A+s!&rxYe}#0#i5H2B1aDZTLwq*utVA!*?(8YNDrcj{fvott zO%dKLpYHx!x9NNUiaap>)Tu&psPJ@v=fFI90fH-Bq)vT3HH~{(T(3PX{+saj^tF34 zQhQe%*x82@sqF}z81!#>(3Y$-YkfrOMYPf4;`}y(^GDcH7Fqcx+rgCz z!_C`8t8d>37}Eudre`-CEsNJFh^Sm2PVX<$9*MRnf7Wu}KQ8~?9Bt!~kHhEiBF^`G z;5~X^74z-~%wR`^ac|;{-k&2bRaZ9@X~o!&^=VsKB{qifewl8VBG;bjK2M44AU!x= zk+(Lp?5i~gW*&{bp8MYXPNAKPmi)ovj+<=~x`xiknD%v_LUBiDqu+Y2Zz|0*wl%4Gik4iwFQ!WvAK&gIj$g7^o{>?i zu7?aAtKk;snRO@^yE|+pskID&_4oIN#phFp#A` zKHwd5pzEykR@pD7U2hE6w@JON5*u4O(?@*5#c^qUTw0*%bn*qZ_dPG`#Jc8|^x!$> zcOnwZTPw`2C8->2U0xK$ZnEz|`9 zOc#sJch0)18@7L{&-5jiJUUNmU+p+fp;@QYy2_Tpf9vBXyjlw_UOx5RUjML-Yo}`C zp>e$7V-+daIQh4D2fk5jy*ESUQBVGg-RW(~D@^}WzBo`uCoU^fW$9)Kw9NHr*~xo6 zU3$rqn`P9C*L>#~ufKbBtX})@$bgn@U)dSEJiL5xd8vqsM^XV;9_I_nEuVdL-y`{< z53*bgGHQrR3xEql_REm=%SHk3Y~%Hejl2JuEjiJv&(oZ+Pb}TdynWBzA$!NL%!!Vg z)+-luuY~B9D_6K3bF`{|YQo)KzivVg8*j3ua40>_FeU&M$3yWGThT}3Mp|0$&z!Qd z(#xy%?+L9);(GRLkAE-=K<DeR6ZTT^(DKD;|6!l2Q>n z{IFF2IDk?3BTjQxfWQ2BVTEGIa`1}~(s^oJb(pm8;-2N?Z zwSp>$Gt7~~3We9^{)S2`s~vVGLNF?_JY&*SagQz6oJ|uuGI?3!R`cX%iJQ7DAzNf^ z>nlzwb!|%ajd#@D;xx!k92<$bU{@@rZS4NY7k__mIR(_2iN_32w8rw#roX%P{}*xd03lZ2G?C1sKP0C1N*lt({_i&0R8ZTTIK= z(24NubP?j9%l;AfQ|aaIm{e!$wbOG3`%qs?OMk^xTRo@pe^USMD%`+rOJ145j|{G; zYULi9y{j?KDyD2*<@3R~DR(Z|{hRNlOvrTXJmG@}F6>KQHPq$LU1Y0k>he#oC}+Im zc2vM4doJ2@xK504k>`tsj25q>i(lYJz;@e={@uAn#@6WnwD}} zY-_%zCNlSi1}Drr@*|Vuz8#&&`tF4!@6x&#^7~%CetGQhv5?>psaL#JN(uw}OK#{c zHy%|yE3=uq;p%rG5l2I&1?ucrw&KdUMSm_o^Du$;^K+-9N~MARv3joNvD~eHjd+>8 zsJF`>*+1au9{!}qW!C=2?7r0MiT3%MDR{PMy$H};B;T)Xnr7eq#yfVikFASy;@+C# zzwy@!q9iqhqf+uV1jX_va^~?~2H&(|h3Ff<8#Z3(T&=9!TWA6C6q~ZzgbxE*wsT_B z#g8laT^j{{#vNyqdnsU#^wY0rQq)!Sz-tm(Jq`MFDYu3KqBq^~-aT@!7WbXXE@ zP@hFX4i1;%zNj671Pn3_K|wrkq&jG^<1WIdJclo>rs$G0voLhkwe6wrHcJ_2T^E&W zPu1@ZoGo!R>%PN}IB*}88Wpj*%G>)y(I%dCieyz44(#$%YH9$F8roZXF*hkVzABPm!+Am0YmjV5SMn5{*cB1>Rx%o%VLyP6S?aJC+E)dJMI<+MHHPW6gu3ZOR02 zY3MU$}2bJ~t&7LHwN5d^3R+KLnE3$eqScUt` z$p8~4Z9iqH3frt3wl&3u)XwhNsl18JI@={#qdP4B_sJ|BRDf7qwcct(NyqG42~LMj z)EJh!e&ARP_)EHrT}ZF%qhQdqoJs~#kIqm>sKMgtnTF;l0L`E^aCNNdS|ltsbUvB6 zka~-yT7g=y?&6EOq68(nS=yWZbmy;LgTle>?twbMD2;p(L ze@uZ_WZl019H?Jq5;9{^;HL~{c@LN3;vK0(( ztbXKiZ^Q6Uyn7Q{Y(xX^?%6{Q$;qOo(H)rL=p<#HD-k}pvHOLY6f~dnEX&zg`{7E&ua(*r4lEnNK8nCtEpJldqmU8%Qi zzOF2RbKy1WUdDr>!u2pSwqdZID;*A0Wi2}L;W=&vDAH9A5+Dz|qp5xC;cZ)2t0{*J2*)NnFM}HF*{M^2!48?BxJ3$h}uKZkDA?Em_w-BM@a>R8O&TL z9()CMFp20|dQc>%efZ@TC@#1Sm`MOy@3piJJ)*$LOMqq=D26@S$u*m`b@lr7!LCVP zfXlJCtb72%BC06j+9jjPsGR_vw8!>!=DR=`Sqoy}sGbj5A6flh@~ zPZZ8AMeBVhIv?}S%+Ypt$7Y*6i9cbcxYM5QJk7v33VxjG2QIti58vZ|mP)}4(X+CW zKfdQgZxN5Hk7b;0oWf5{e&}skz=szCsL3B7)YnS@&ABCyFTt3J-cDRPqO9m8CMG<` zGD#R8@`rtt2v%WN9$|~t_^JHtS{DZv3z0=N3bFhqHkLP#X;hWC-r0uy)C7j?hEo!G zVprAr_b?##?3uW--|g2E&1ZT4ir^kttQIc0pu*0>ueQ$8RD0CBBW2Z>+21 z+(kM;Lc)K!ci{AKWu~D zf-gwD3Dp%3pE0pf-p-}CdY=jqF@oyn6k!9u&>Gg_PNPUjC=KU{v7)~&P~*-Jy);84;O^d?cn^7= z?;YRO9kVtdTs%mv$AUVlJb%abZ4Y=CwWRDD12I^Be801^{9#>!q)7cB2!!3;mG02q zEOt$IahSt>3UW;*glySWme$+yRbpU3V6I$NLQNskRKkB3B5X6x0W|AN^0eT zAadJHR3jYrtE(!CL$9L(noG!CENXdwnlQf?t5=#(B8>InRmdfua?M8r&u^r~P+u&d zzIY67Ql4F&WP#ew)c>$t680Ixi^?t0*N3%EZ}!lGWCsPZ=8>n9M~+vaQRC)ahLgumnGkSyD^BnPUDu;ZrXOP28N z{l+^hErrBS=2ndr);?;wckF3-1(a?4vYt{s6ya72faVkCkqFsVxQIx)_4nK;$;vrg zR&lXX<4&rl-Q!+rr;8||8vS+F&^BW4tlWJGh~QR|R3IfCdTEAum~f}v^9#1FRvFb# zk#(X7(1mK&!y&FN@zEcLhReA9{dgt&Yxow=%;6o4lB7tF+-F33Tz*S`PLY;mLH=q2 z?|bhl^g9U9CKc!*+X*{0_|_&vk(depoF>r1!?t#5c}xIn5@aYo13I69EcLUW;Ru9` zjmu0WKhv4NS{I6ph-xzai#)GS^}lHt*uZ_EueUFZ`dsI~RP98DP0qQ7r*pZQqyLZ} 
zo`3uHj&X2S|AQM39?<;g^{Kq}XdAO@uJ+_=Ia$;Y;3h+87wCkmdv!OiA{r?ML5n>d z5)b@BUw`z>m>?;mu)zJ^PL6YzV%xD4S-QQtyO-Qx(vcv5v?a$!goo5|-QXQ~33-m? zT}t3hox;sYqK>PkNpU?R-SDu*-X8GQ;_VUDU?7W~_n#R0047MD2qzG5chZx9G(OIE zdQ*3U896K(!e5~}&Ag9D9J!(8P-ISuc>CSc_prmByR%g>dP zFdGgig2Wuj_1mT~Qejbg2NHlh2C8p&OB;oT=`IX(7-Aj7mlp89^z^T~FJr5CIHJX&FqNTxTC|-jmIL^bDBWS*+Z)=@q1o0~6j!1!vjvV4JbOA4 z-f;;VP^S>YavI^X*_(q^N-gKGqGis{H>}Um<}yIG-n&jy#M7d?^DnYNb|h6%V|pKg zOHV_2ZyF(mY_okyQbVJST%@4V?D4^&s<#byO1NjAaNFAnFxZh4s+wNX65pn$EJ+kF z{II9>lWx9-s)$B9b3IquE$YrRw5;X!zo8DYl=|7a@^n|8wzTAnAcEn8@4yM>EDf+P z*p`|KRg^qzYu6@cUV*>= zTcBs|U6Vn+UPv`T<}5hg*V2#_zAen{4%y+UH27IA?F``KAoZcWkJdkGD7W}1LtW`+ zs#R|Kk6@c3MTvby>iv;BEGPi2{CRRI6P3Dog1BE?vC^q43|G^XF3)*)48a?>0a?M9V|jGADA(ur4!H&?@6?LoT;$!K9P;SW*!Ww#d&3T(nOP@c?{}ENl+v$CG&qcIR<~_emPVJ_p zqeH|)$t?!A-?;(%)|2hkRn0bKw6WQm~A zD{4;e6u8Cs;C^ph(|y>J72z6dauu3OrRxy}71Y~ed9s>8y5eT!fYFW=NlbLyi3vzB z*8u=7mA>8QyFhvy;a%i;5S66=*FqSEgL``g3q_R5!Mtrvdgw`IxC{MI_#R4+lhD9%n7DBN-KG32Cup60L{(d zJP}6`JcLf6Xr5m{s32Os^^Ku;=kEhm%um);X{->L$RHLk@Yz6z{~ON^cb@0KQAP5f;d_r&-h0TRw77 z&RK5=6W~ZGI@_GcS0+fDa+r_W?jOfAZOHR^cO>JHR>I0e%?tMUb0=xc-x+<88#fRs@A7#h`^2PJOh-~@5 zv3op~cmF5>j%WeI^YSPOR4*`u;81pmcjF@YOHMT<61E)}?mdQJAiDrjsD}F`0kE+F zCp>KH)IMomD3nwDhZ|)FD?EH)#g+9|Zf+vbD=bIcPr2vijie-o0o+B`}+(~Z50Cjq|bMJqKPazsgF?@x?9I~Gp%Q=g9uc-mRf(!;{=XRo@j3!vH0Bc@G(&z zpcE6)XP1`qN{z^!SbBtyT!Tx7Jjb`fF!mWcSX{c7AW!)Z6zqS2n-$ae=*=8IIZZNr zjKG#WmTLaGy1AZdu47NpKse=bCJg^1E`Y6w>8rYk>lQ0`S$Sl{B(FuS@eZ*E^!02q9*|epi>zNGBL5V65P6Oj&8*S=s7sKVmCoPB z5ikP#0&*dRE*gEyBY~WFq7Q{G;az#md!9q#2;Y6!^0GXEH&_1>IpHc2l^tu{jFaFo zXx**O9c}TsVQxY&u-2B?VH}mj!`r$`!YPQXnl5(bn!R3PaFz^37M#>I6MtBY4p6u0 zeLkfQ7rX>TImACs@T8tICBW9Jgx+_{s|+Kl0e=c@Mis05Ud{$Q1M-1rHaAO`IzfKT@tF1Zm&vEQ+G zN5UH_?S1LJTzr|S^r&x?oGCk;Vnw1#a0c7#2P$!f2a5cqY9=yiVpF)TIl9W#g2J*| zW$MNCh=DbKT7`WZ)bv-sqroA%&-uiwy@W#$r%m$xI@(7lO_Az^cC1|AD}Y$2_<@=D=mE!)>&!Mz56zaCTR|*lCm3lEKHg`>{q0FW zQ2Gs;mlw9Ca;?U*EQbCh{JAil!tZHl28(Pd!VWdYUSvHy%E#A`DtMx8GeaC439IFN zq$G=C(-BP_m9ZRUR`#@!?CAlR+lLG;T5z$C+TW5$G24YVu|g;p-s>wl=eezME6HxW zH9G%j8111ihtnrDOP`o$z(@sg!j_}uP-<{Lg@U52%|38EJYo7qN4i^OrL8H=h;2{d zwkVv;v_Ki|9v?7@%#lf<_&a{^`ufeDfJGA&3shJ@z4s^7t`eUT&gYdrkU^78c8$`Z z`V7*uFv9N5>DewsR*&k&e$7=w5i3j_=|IWb>_=ba27`@ zQp2C$(j>rHVDl=mAtWgPrivthJV*g~Gwg-tFM)RV8pD_!obMS7<*6GHn&#wE`jTR8 z=19P4F+wW_BmI?Y!HgRY$ueG|*D>(89BJgQr&M>v+F~7@1w^@OwArUX!))DVcPpT% zwzCo0z2$ytyXF7>&3k&5(ABYbV(t=Pm+y`fHuMa7Sj;#hG-tXptCY@;h2PVTo@0Y&UAu zFA8N+B#1-k%LDa&ZC882PrtNONNW%A5wnIxN}k?p<}W){nOp$Zk^o|sgD8BD^c3SC zg)ie$?>Kx>IAh(rRVWs^f$RaZtd`?xZNu@Xw{1b3iW{4})wDaH#kD50wEp8&qlq?U z)VX~-@gOaY8yMJEP!_AR&p4*!P2k0=pDn-GVk&7;s96@aAVw0 zu}z)OWLvwn-^!Uew3bS=fk}3y^JGKpL||{tZ2}jCSp?ITS+rLoCk1S^FU~B>;`EY$ zvWz$tw9M$fV-RiXsKBbqV%P(=5L=R##I}qbq=I= zU*>VXXf0S&v&Z)b=kk9jRrJ*jXXb^gz6k8olkNPs zE!|%J*skt#M#N?$PYCbYNWe(EUbX5y>SWWXlSRPW3Age*cC2g8$^hr|r4@&Q8mw(~ z(4C`Uo9(e`ySiKVJ_NQ-4=;$id1&m6U3EH?>zJOWbG40)JCeUF4fl-PEcwow_yoyg zxmhVR=+o%q8#R2o9;3MPmq=wn)|SImgB=s&JS5qrweOob&$&Xn^>7Y=zp{hf#x;yZOet7YrQAN)Yl4=>MSypuMWmqWOS}(68 zj2E@}+_RfDW6SE)XS5tx*74Ymr#hhHG;To1EZUYJcHxWNcehZ|zmPlJt6klnj7+`w ze9h4f-HAb$;|}gD_vmTa{SY0u5l(DzN2(w41WIm6qIR;<+Z%a_6?)OqQ>ICP6XGsu z3Bp=YpBb9o93Bv4V@TlFLXl{1qFw!+i{BRR86MC_eNsdC#3Rv_{T*hhXR5o4gQU_r zOhY66?qsI8*vUH{Q#^c4MbUNeJ+%7c(yy<7F(u%gGuCaVca^JM;z%A9nI-$DL__a&W&42X;t;RLeGrP*81J%(&j?V; zOmLXlckrrpz--TG6>Tv36E&Wt7rJJ9Z~0gt`)+zho3HI%n|-Kzr!6;@T2@>>iMgIN zlJN96BLj_)8uQ(>35~LwzpTwc&puSNYa=x8ny0$8tne$-u5-smQ;~w%Gt|RHgxc6P zHrL16`p`&(bVj=)Z_@%pB`z2tp4W*{V!G{FDy|(rk*kOpr!;EW;js?zT-n#VjU|!E zw%Ea-JIbWLdZnvFbJzc+HeesJ}*6I**8hRvDJKezk{UtaHan?ii4z zWoqw-v^DmQJrL&=O6$jG#%rQyU*SaVyL$jj%YU* 
z8HaV8dA55sq9De)1q)ameUHZ7J?|scr@(!i^5&u-?>Gr$)H+#a%=9c?uxt5f=hjH@ zJ#gNKi=23Wi0X$tv&R8e$XUHmiBDHE+ciejA~&3{g^b+|S2OoxeSenHB1zv1kZ z3PH^qLHmBeZJXOkBh~F5-mTJG3eBR795cVG_Bu?(kPj6@`H+tXieN?*9ZV|XjSm?c z>11aNi-rk|v}E8h-iY|E{(<|O3xZ~W8UMaA$9k*2b%`vuX` z(MA_bj4x_mSQ#FIdWP{uoVbjRvS#gBGUJWgh%h=thm`f-`Evedan8{ihg6Jw<7p4U z9NQ1g4ls!dl#!A0o2QpZgk!(zKV7NDAEl_fZ}@=WNOp#1Lj9Xzzw0qcl8Or9lQ&cp zk@!nI=9t*rb2a;(Te!$XYVlxmK_4$4mNCxeJ?fWxzKBb`2s)SL`hh!K%J-ry$c~@V znY?vi6eE&gcj}{}z~l8_W!qK%o*0h_U(nclku_C0+O@GxiS=hHJl4q@%ZVP?wyZ!V zNCxTAl44i=Wu=+Y>1xPiHy3rYuVY1GY2tnX8cbL4Juo1v*K+a#>ge0Px#^(jwbS{- ztBUl=iFW@igt|?CElFuyw@~ddt0`JN#w{YV#w!H1lRo1^nl-n-Z+g8kc*8t7#ETac zZ_?EId;c^;s|RatSHeyfc4Ry{}i>G+S>{nLFM=t+(SN7$ql|95Yt-^97cUAi*z^IKlr^ zI;5Jbnbmo;2#tQfTitI39Z<_hO;#GbH8gyyeHLXw8P~qH+E>?m#kwjxMnnxG>_Vs& zdesuGM;)%v-5BZKnh|NC-$L`K}=QfG!8D;l{`13dX)^C6?T6TZ%G9R#Q`(;hNT z@$_g@<}avs+uz$6K|IVYyXY>>g5|sRgH_vCwu_*9#@?S2@PDr0e*Hs>T7sUxm>!Dr zs=b_Y<-#FJ1tq>0v-%>1mWpT9+g;0@CWo52M}rl|zhBe-pgDUw>M&QB&wT!O@tNkz zJ`zvjIIuf%xI0@9ARb@YiY0h%cdp(qu<$|zXrhiQQo;Z5ctx<^$c-|!{$z)gTg0~+ z^XshlINvzKr+c#MsWngV)$)^fPI5MV$9WO*$n)%mowwb8&Ejm|eJ8E|ck$aZ`VZB| zUHmQKhklD}v;R89{OJ<&x0%E9zn4Ci{qE7y|2*dXM?sH!)|%J2q%mJn@3{5g(9mdP zVN*m&sP)p&8-+(#8}<75G!g}nXIrHv8|vlW+dFT0_wGv(!|B!&J%J_Duhhmh8cXL~ zp}y`Z;x&q$=!lO|;11rYcNKhVy41Dz)3Gq&waRUg(cMU?*KBsKB#5jU9WG}ze; zx;r!TVkfhdfa7(0odnhNDmX##V^g{p4O(ePv!=SLm9)bXRvgQmgL zEh7WMcbLxzNN${5z4O>sml<$dW)W>Ju~$Dt7N^o&B+j8BtKiy0htxQ{VBgbA%`OT0 zq}Uytra3OysvoZO{cz)?)|wA5+o@m7txOyrYkt#hJh69AZ({D|lqhraF9NPEu{1LW z>5A?@DMV=TPjf$?L`QR%8OiZGlX$3{4pooK_(J05!E)X_@9&1&O;J*RD}K*+o}RoC zZ&}nqPd!`B)!Vt$&v0U_RuCGk@2d5a-bR;imFnK5tEaicwi z(@Pf0|B{J8A_(Xits8k}{X{;-t~DkqsyXoqZ-(e?7;(62)^FUH_~f!6+TvmsvU7Q+ za{CIln~LPOZ47ZQw*!zEiJyHx*kNfZ&an$a;fBtnT=kw)?cZ4afDCnl`-WcF404g% z%=6!_)iU#bc6#c$wWoC^^yCS?`d&GVCx z7sk*V6h;e)`|O&Hvva?n{>11!o_KCO_v;$L`cyf$u{6TZ?ZnQ_`%gwsLHjgYZ71KO za9*~XGRtn=kXti3)_mI1z_t2XhCiB9Pyl&G^ET_$S@lciZs>ijW8Ba)-ux%RA5rZ~ zt$tXURVQUNol){60cJdN_${6(zj5(}tfs{k-*PUD59;R(s#S@v2wWV#+w&G12IC}M zy0hK6ZdkCU(QxiM?Y8!Kg=sUJF4qQU`kIg+fEp1~R)5Bthr_}KW1U&3!VB_MZR?Nx zG2l>Bfm*zxZ<{`Bo$XZt3ib5@nULihrFXJkVV-<>Te!{5a5KI^Q_n&Z-F770&T}G`!|3^hQ(3}qU@^FahTk%Yop|LRum0Syv z-Fm~MDmlSMQCPC_&8`yWGz3afH&n)8eJa8`-y)^mVG*+1V}D}P znF_VII6Sc<*T_io7Rx@|%ZYvW1+^1v*1V3rnjiD?<+l3q)6K`FVpc$gG=DIGdh*Y! zcauXY6rSr)5YQZP^!5s_Nph71(CK6pJ@j1k?AqHmubVh>1t$X`HP`3$I5@sdBd<+V zK*OTp;g;K?#lB%SG)FGR@(d zHhyzg<1wGF!|E4LUR;@zNeSog1V9q2*%As5t`|i|M;!hIkC|J%R%Q5<-=^!v2?_^& zQ7U;#r?U*|fvA$6&Vh#-zN*~A(`_M4Ny}R&xfK_;LAmC=ryTOL zS*k*MM|}HNwic%rUj)b!R_2&eef@|>&%J~=g@fx@gP#eAJR8ybtf0JptNC<$+oIde zrk3#}9Qlfh))PQq7WBTU__j2wSUWX?u{DLCByN75)Zn|hxfmX8?KB{xB6@em2_KD~G!(D!>9tG-PdcR!scWAK7h&6*KZ|?>71p z2Iy-s4SDYFn(RK(Rv|+c!>C?4GCz2$qLIbK#MPM0jG3Ewc9Y+?nS*^8CZCp%%@iE| zI;znzQ-+)ykpNkh1Hf53hr%Tn!6UMy(T3&#m8m7U&dSmBXc7_c<+gaQ(&?Ywh6o!1_g8wpPifBVxqIgdUa+1g_K zi*L^#?KpxG>nMR{cT_G^c+eNRdexeqf7`#ZN^b8oKy#SKP-bv^e0$id@@MCEA3LWH zTH@+&Jiky79&g~P1p%y^ddgA%U1knpOMliJdBBqq{QWKV&f941-6I2H{4zU>U__z! 
z8PuOC)=KDUWW)S3_jaCa?$-~V;b&$(?MiJz&TgW+^C?Tte06-*ZBhQ3NaxKiBAk3Q znM6-|Bb?+SD9u3aV{Fpx6?7@KF0f+sy7i3Gl-l0^{o4;cqke`_?o^lfM>(nBjQj@; zuo+BPNoTddcNDrLHWkPAf~l;EjY)`Dpj-9|LMd02PgB`nN?f3JW6;l=R4#np)Au%q8Jj6McFm&F$Yv`1p?A zvo2Dh43tcD$A7>;>1Wnl&tP5-J7l|%CL)np+d{;irnkrf5tFb83{|c0bf?}C#pUx; zZYvyr_C|0`(4AvR{I=u9=0HmWSqKF@$kg&S96j!3Y7m@2_(MwT?In8Xt;%5r8T^GMT5tnl!G%19to>#YnE`;S{*c8@ij`+$}fhanrR0HP}jRt_0tB?sF988r%1cfjw zt}D+flY%w8zYkTu0=5b2VE21H`<&KSr(#6;>G0Ai%dy3JbRtV zKH)a$CLPq0a~j*5_-3#|MoF`7FyHzij~;e9Z&isiPkWm=Q?$QCI~dJ7E0~cRSC*x> z{sTuPMqhofKtr}AnvxPr5}2q2YwSK}f7ij~%jDdK3c(!91R{hk(_*OU{g0MG+3lXI zc{2lD)|#3;QH4zwlOydse*`BzTC|9Uz-5(8n{z!uiVma5^l3?VRwro8KU4i{6}(5u?w~9v`$L=7o|w zJ>eJ>ai`(=ErEJ1SX#_tqP+SsU zc}?m$za|E!&p^D9<5&ybBszyu;=B9}JDh^Fp%(p3Oii789-imq(Y?W zG2mZVi4u&lSg@C6Znoym)OaFgX3|L!Gpz>Ntsnl4KS>{byG@)C%U-n)6Snl`YT3K{ zPO21dYMAsDvangzkrtW0poDntVE+pr!4f}u%BHiM^Zv!YhKjlB$uJFOF!}pBUo@9N zB#HVD4*68X@Z55R*;j65<+F2Huk!`!nI<6|uaU0pl(vT{ZJ8i#6T)$CpF%dgDlaF( z&bjpG^?jp6)99a$oUootD8f}&a{a}*X40U*_?E))!K0%=yc^P#;g0w8{s4+`fuJDD z_dxR=GJq~K^4_<%>5+b{cXjv2*TGlqJw8u@HPlKiVE09gr0Ut2g6BfiVt9={36weo z`d7h{% zdVD=O=Kawjn?;yBX5}tIkbUjXL6CE5P8Pjw#&Y`#2fK9qUjK!2ZOsObIinq|n~Wy5 zDDO*c|Hf8P!!I)fs~9PHa*ce!Ma2Z@wYwE2p!BltmnXnzS9_)f7(zldH!U&F$k4qm zN~p@*4qO}^F{l%Kcm8dCzE9m)q_u++T}%-5j=877`3rV|N1m^4yNWEm<%cTZX%7D@ zC(N5l_*^gEY%x*{URv@*T5rM~E+amL{dFER*^JV)IK0m~Q4(8js#>9k|Ctw;oe!n zf^wV>?YfEl=1@5zPx#$bO$fjKS-rVAvn3`bX4SSjr`!T=&+F5{s5p+H-_3708ONLK zeYSjYg-q#)iFQ4yQscx>FH&U3-G89=ppZFb;72q=rux|l!}U^OM|U_svZL?qFQN<9 zqSiC_M$~PPe3h#`leL+?5j@=R;?ij>hFl!WBJH)+9anLigKz=D+c4RM2wk#j zdZS;&;XTT}n{iy?PdOSdXr$I$I8BC+Cm-65x6`x-xtFgaYp8!lo6xJ)p$d$Lc3*ax z1A`_ov^_@JJ}$Fjk^3QKyGgv?JatUp1w^!`#eMzdHTQ_tKvM!)JGpGt7Do#7Ca?5- z-I^;iBpXymojk>-sa}MzkngL=E|SaYyOetE>K0-@XEJ1QelQG_vCx$i+2DZRrsV_TfAXj7?a^;H@vyX*d$IWgfgV0*V`|Fj7;bVAls_T%pVdB=$<58NCU0r)J0I_!1*_0L38ql; z#RfIO5??cNf}~hdx4$C>d|%%5kCf4BZQXQcbqx5Yq}vb1%*R`?D^X7#PVMr|PM6=5 zT?)ZF@^0f^i3he5y0G=HN?P? 
zF37rq-E|r3sS@12K{kpOLyb7prCr5~r-hZEDg<>Hrc*z8%P)uUaf#h|q0Wmc1Db!$ zMcMLv*=?qrzMbjH=}l*Z_9(-z?{wajSX767SnM_)OZJ?~aE`1|Eb+{08$W$jYn~~I z#iaE$67HtUm}%YH3`t|8QMcuzsJk+5q*ve(yr(r7{vIo9@@!dy<>l+8YjZ8fe`@QB zgNaD5nxeG71Py+W%5hCmc7%Z?x>zkDYe)oo|G)DEajC?JqgPk!OO@OByt>gR&Lwc z{Z?=yl?Z30FdBPf($Z!tcBa(w9-jVq=3JAj{vx&uqFr9`>VVE^L0^uH+v8G5d)W zsx5b{%aTE5-=MMYJ)94cQ+(wdYrDK!+FztIzV+i5p3%(Y4a?ySiS7ug8wzhn&64gi zR(Z>(Y|0?U_VD;v3&E=TxZSovP1tU=Iww+?EyCV>X0ECDWkpY~U<@uGW+S<5Hzv*9 zBx{euUDTb^LD?vfej`^*{IXL#f)LH;KJol9SDB-8VN<^ydcRok7^()_qE6(55{KB1 z9J?L&(8&V>oqafeago;%NTggByWmkj;vVJMvIa!+>!tnmra8QoJE>VgZch$R-zaNd z5A`+nZq$%ckalK>odrGWA74P|^3&&Cw{cD^KThvO7y|PA8MPp0+K%+&Sjk&Xl6HiT zJCj-6^$_B6QLxaDAjeHfLg_ORboO1R+G35FnO8o@88f(%>Z#=n5 zEwwr$0~SNP6gfOA64W#hHmaf1jKs=fcY;(Fj#;Zbp6ueG32BOmW`?G2iTUfmy>(GL z&#g?qf&JjRQb~UTb+7BY{iV>4K)Y4QJ4T#NW#cRlYk$ZjqI>G;XS85edSwjOG|zF` z1pJMUPZtbAMDCnd54p^6GO6VL88(D}Nw;rE*2I7yf8M7Lh5DqXFZUI-`%b93IYO4E zdcD)L=nHEetQ7`;LG(Ls1_l4#B5RhO@JUQTkR59l{%HSZ6c6D~4d{4TWr3MdpMLhT zeW`s-`vsVfgt2XHD0_WDVrbmKA(SI#Hb#1P?$b+f^FetBxzzU{P52^qGA)iAYxWK6 zt%mAjM6LSMdLNS(MPLO^eQekTIC>uQh{{n4pc^$Cif36zNUi}XWRTMlWudyD)_IdOCRuc+HS+k{O3HUIS6w2ax|DeY~|cSq?ro3h`Gs7dnqlO zmz^qXnPwyomb-2)2q0o@@KNOFOUmrJNQ4zLSv}A9EY{^bF^mSS?NjSob}eZ|s3x(l zYPSXna|ESlD6#)E2s{5Y0(xd`G&+Fn;tz^RG)_)fKl`VBDf>}Y%MGI8_Z5|p)Q&$z zM-aO_J_J)%eSc9{Guqka8ylcN706Myag52ueTCx^KGeE|*j2V0DVb?yw-V`d zU(a3}zl#J0DFc)H*x%%Bsr3x`c(@l8Y(tGLM{CvYla+m^EmfnCG)Ddz6BtQ0H-`NQ z_c8>WJ+*Q3?-N5@;ul~SZM_aUOb0R{$@rlqC9B&9s0+w;u|GO-Z=myP5+%j!%;0<3 zMBpW}39Dbw`%UB0C+<4$cZLzsc0&$jJ*0IIO=#vq2`ctKNQsIw$74jfm(dDE;DLZsl=Vhl;Fc2{yW`iXpZ6E`uc7KAJ&7J9mFt>VSPfKaax_p1 zTXdY(0=o2+bQt=mfwm*Ks#oD6q|kNy^z}q*Z=5{)REGrlxvq3v<9hq9bCm!C@v`|A~&Di_qhl-ACkzA<~e*I+i#ag24y!z)+_F+GKFH+4YwAEizUf?5g|Z7I^Sw?e2(~@WTLL$TrB9j zWh?`fN8nHacB85B%cCtO6y*_GsWlnu#=f#?-tufao&s82<(sQ2Fo+7YRvodh4U%GCE0e8O%LO2h{P;apo}chnd?{f z>>OYi&8MlJet23#TDdri;H_n~ZU}r5u=PgeAWH;JrlgKcTWCb=ZVU5iMGgX5QX{sQ z+F~K1)oqvIVs=K_9fdU(n%j z{yb`D;^&(%N(0HU?ZDi0Ypi$2@bQ8cqy*~SfcJT$=SY=#u(|6&Zx$0G}Ng@%~4d974)}B1ue8Gw%>g) zLOt#}jxo9XA43_!*uPj8-sydqSUp|TvGachm=WZivWHAU82g6T4+C3j>@+6E&i~j^ z7=Wq96NRXvJ$CGQGRtQiKZE`Qk)?7$9)TdodfMRQU;Xt)m@{z>DBbXvk}UC;HQel4 zb8A&u!~4yE-;~`~SSkE9T65p&6`VoeII22`{X}UhPHFlS{5)J10uAKf@b|ZH1+gz^ zqej!h(F9c>lA;dRP)JIHh>NjiuxX1N+Q#W=6{U1MfU^ZKK_LBvFaZH-cDxOzB}*Rs zAvIak-`1h1vQEoBLMye$gQbUr(fk{-K~f42WXE%71;gGyQ%vds%ve(>2}Lpd69Br* z%J_m$iK2sI-h#7720N`o%qu6|GrM?a&YbIJZ2<`J+0KtB*ExiP9}aI6LTc+Na11)X%(;-U}UCS$-Y4+SbKHIg)GF*9o(<*R#;PIxpTFVzNR?B87^E`a@y$r@IziA!` z3YN~LLgeHkw}@9u^ET~BS(*OIzaZ%{tWl_*P!$bn_2Q`dw3n=X$?;OzeKLm>uvz^; zE!^22ZTi7MerrkG>kF%6W;tsC#fYB+zM&wX#o{24l7Sl*CInjcZ}|5hWl~9EAAP~T z`RS!|o%c+gpUC6L*xMqZ^6a02q{FX9N>FB8>(nTrw9YH^NO43#iJWn~@T6E-!B5IL za)g|FT0Xt3^=dGlX12k}3E;nhHu*pA;~9h|PAHSnbdx>*=lktz()7q8qDK7SGU^3s zpb5(g8ByeEzlc2i3(`P=tHhW%yHDJHC_uuk{@yC4xaZN~- zxq1<=s`N4Y6o({30#2pqhpL;r!#nOSO>esnn&9xR8%1vj2DOA4(MNEf0Ucn3Y$NwR zjS+n9T1HxcuM~-z+eWPO!TzBk5DiOBg;k{N6IXK5yQr&4D~Ja+t*HC2Q}l$<7JH=! 
znAzHlCOMNhRLJpb*rQ(M7ay~t+M~s~M7v7ItMi;e#Ey$>aeJwHfzc)tg#;K?sRR*` z`!l-)bq{BPg*49(-!YY+z#w{<;1{RVRwW5 zifK0@pr1L7Kn4KjRMYS~w<z>y*WMV6>}BLK(ws(EVn zwUXs1g(HgfHrYr1qZO71(`M(rhsTGemY%hC z5n*K)6BnG5IBL%RjrN(&Oeg1ZL-2Kgx=@KMZ)?*-bvhw!BFIc%LrRH`##k6OvZjx**Snw|-u8Chse!m50 z0Szuj0JDR`sh(o{t8>NnK)Gu7NA8cl7cAKITxbr6oqUBj#{79CNXc?8e+z1|dIf@z zoNcK52~T--_HL3&DEjkXatHPm{J8bDC@as@6P~QFxmVVrTVk8c{LvNCWfGXaZ#_o_%wbEtFk6D^uO$i{r8Ji>XU1LDy}M0#9VvI^t0o z!ad#ws^_Hl5HyRlBOM8=q@6eId(KLsQZdbJ=LIA7)mNx?P;T@U0m8xRCek}yQwV7K zDvMfoo^|2k;AKJfYeg5N5wKHS&)ZtzGgsc-ix?2LiNb(}5izV{BMYRK5GJqG3y1-e z6O9E;lI+%OxkW!U%5&+?q@O02c)3aMcQbGF8&eBC`^BS2OaFYH#{aMAqM!J8f8TuK z;`Frt++Jz+r~Jtx{(|2p81aC@#2+@BZ~ki5e`b6wGyCt~@~&|%Y*+gF=kKleeEF;7 z^5ZKFcsqUlyi5N3zl(K(r{3>|8b}fD_b2_L{2I=VhPmkmktT1>;^8K2kt$}MwYR?x zwLxnk_tikBT>!$s%L+&j*GMH~Y@!c)BJ8BzgeTp3HxhioZ?k%J#xZ=lOY=)@Jm9bC zbCoI?{`g9>^H0MrptTEAHVFdXLvcC3k*z}LRQ|Vt?JLchfkbaTG;A{Ml$XQ){m=s) zgC0k&^UU~FQAczMmiXmUCd&#oE|C|)%qD-8zrptD--YNeW42Hqv-E!iA@I#sP3CH< zMyPhCY146u`e|-vg+lU5%RoCkJTi_Me)j&pRK_ z%HJ?VQB@E8K~i^O)NEC~eC5nf=IBpOazBESeG)(b{@XxD;n){D==Z%`sEZzAuIJ|~ zoFBG&6?zC!b-u+CrCU|vLv<tQ`Ic9J(yHmp~yz-yVkO3Q59!n{?

;& z!KDN>NzcFVCyBd>*`Y*5v!E*{Gv+j-R0dgEthoYZFu}%tnrJGbP=p{mXCW1(6btTo z9>(*EGJI|Afo;xc`#~rhXrY%dD-IQ6LR4PVMaG($5a&*lYC*{t81e9WKC`75?gC(5 z(AcZIPZ$1W+furzjD)W-EB>hHnqRmCc$x%mz$D%PxQ&CgilOTAa0`V9tBjJj$vTTq zj}LosAiUY9n{$#)ra%JGkHnl6GH+Z zljV%cwnT=Y&OwNm$@mxi__lB-0CWfjCYRMX4g!=X z{Xx+p0BD-}p+g1m%BW&R>o}ZCmsP`P>&4k5SKdo4U^$T}gnnbP(%!wn<}x>_M}PAc`m0bX1(Qx` zN0N;k=0cJ%2tvW*<5OS!GCj`XS3=edww)eM5bZJ-VPbQYHml-*%2aLgo72mS;a z5Lh0k7V!vp%uI4C!|>k8kq&m2G^Nu>oz2rrWQ(@|jWtwRX81!PMp95XJlxIE>To3w zpGhAsicZwqPYf82U0b(te7dYhi5CFf(0#Z<2*}IPGRw|7hJACP0So!pfbw4KMj#!; zPDMo#VFR(4RJ_^b!(S~^>n5qR}-_w3QL8Hx_;Q4ShDMzcnuPz~t8W6zzWV8XI^vZXlPuQ@= zJTD3Z7BN_khB``T!axK3(L!yC%1EF*mW*ENz$TGP^*|h1{2ZG7P`?YL^}B^c$c6p! z@eY#j=IYXuDn(>ioD+dcQ9XUU<8<`MwoiD-OjtDqu9u|#ngj62Nz)ekCPUCF(>zND z70}XELK)*u6CI_qiL-305^m3f>L-X7KuHWu%VBM~JW=ZY3Ey0U2IFABAmhzf0BZ?5 z9|;;E=o3?FKr{#cYr=Y^*GVh%*sMwT(s-!_mPTr3vygxi|-j2`gOeK7=qa{+i6*09zbTvoUhHHe6#j z)vdkdx03eHo}DhAN)h`s;?etB8VlG_qt)Msu3Gj9X_t}dkC!KG5lJY_RQryiBpZoIow&?9zh zICdT8TdCe6?8ri6v&7Jc&vDhT=k(C z@{NnvjvYozIRrRGM*_L6()x+Q;!G#UklT`6VM1hpIRUHVY9iZ9N{9Q_17J%iL;S<2 z?51TI?iOf8Hn+fdRr+k0l_K-48Rhg}W1+ydh!gx!U-#Y{ zIv-bZvp@dFPkWQ1-|XH3Htn zbpt)e%sAQsje9#{3hCz1qY|6+FPve6u^&O6PQ_7R=zQq~v^VFbA2vMN)dQsU|kjA|jq z%JPxeAWCDq3V2@TFf3cr8f!8HBB9PvG>N^&ui+sJ6 z`M^*J5uZkM-v&s8A66L_n~~mAfW=@E@zi;}qwwjx=6M1;7<)>%xPA=DZZz>5ul;rE zyTplN<1C}k{Qk&?R|FdF%Gi(78QVO-GNAWkrMcP1d&@{hlQZ4>8bF;Tlo3pXZ%JA( zAy&xDr;L~O1iBAbY1U^8IP2Ls%1iJ;>lrzvTjWrCMDpnSf~NT3Wja+OR9aa*G7LKp zuqm{i;k7r3-|IIVLXf3o=#!bXyTF<_$?Y_c2+$cW0Us>gPWpT{3Hf zKXI>wIzUu*JL@5rMpR%sPSByOGyu-iW<-P)KU>BQ(#4O~Dksma?B}J+p&{xMFeem3%{@D z(vRG-ZKKU+cduH!xZ?fGqtW{6FZd9$I7yb4KEADNb2*`Q4vviYltRk;Qw!!4+B!pW z2%PN}@y1Et-hl)HuR|_d%f&e!bIC4xL$v%2E1j4zrxw#QcjR3<04#{HMQ=qlB2?Ie zLgeH5mw<) z;RPETy?shC*YJZ<3zk9Ft-}Kz21;Hq2)~F^WS^)9vBR&9vxr3NzW}3V@pp|liG@>LT8CZ+x8v;>qb z;{0To4q&BKP&^K$hQ8S;(5})ja$0Rq1_ePLIgkh$TryxY5q+6#t>;{K1Vb7>ew;B7 zWCq9WL&JzMd^!ExLdWl#0 z)+}L?vw;cCqH>WHZ*jWfGwUNYkdlX6rcUC`1S5Jkm#FR1UW_Cpy>Q5iyJZs$_brc+ck1@D zoLYC(XX~zYuYQ7EGsqb(vc{zQD)a!xp}Dr_U+aPPDM2~#nPg8yCz~MUd37kZee6S_z^gHwoa|}v7c_Ftp|^GF2cDmdN5;Y7iF`16 zH59T{5hn*>@;0@08Lk_+_4#U6dD?&Cl_>mV@ z`JWC@o(f70cDtaZ_lTfCimczie1`aP6ah{VN4mg%+b4s`{o>tyh+G9qt&0aNceo7YqR)8f_^e!{dXgWcb9?D~1D!F~%1m zf{0Pv3xqJZicb{u0n$W1dCMpS9*N}%4^rKy?P5AwohD>SM;A{Y2vlUwgxG~|SA2f# zKEHcTjYkWB?t&44Mw2kv-1V=*?0B_~B-lZ&&$TOa0uG8mBL|bD$0Vs|xiR*?_InL=l_rj|Y=RmUN~&v#gUJgmEgdqtxM zjD&20pO5oM8`34006sv_6+wJ+gu)4=kWkLNy?fvPQ;dSU2V*b=Hag*Cd{>#C;0cCq z%|^ThO$iHaKFS*DeN;?R=P_?~Tl;31h=88JuqBjN{Qw7lZFH@pK4u{z)Il!JQ3Wvl$(OwDtY#bNl@)Eh|d2NVX#pHZ^Muup|tpd-=ea~3vz+A>ZNz_46=di*Bj){J%%l&B+sj}_r`U}^@zlxBaz zHnj79rH4{}Z|4hZ_z?Z&!BUQ!jI`v3M{|wAgCQauJ9iX5K3+lmGjauF-EI05HDcs% z#;gckyC!r^QpUfl5u|*gKycZWnVYG7CB~KTAK7!KL<6W(FLNw642DOD@OX2LU2u@k z%4*kI1fI7T`;G1Ey|QG=h?3YzEf*U4;s+XBItv%1H-Cw&sYNJ!f&nikFJlx_yCx>4 zSY+ANWx{RI4VjCz@&me3Vj5tv^t}bCFYWg8QMNg=Y9^C)7P!%tV@s>|1x<u(mYT?90-@QnEmG7M|vMBPVte1g_HvumX+MDNmk$~p{+?ODiI1KW80o6 z8lTZp@aPLiY0~B@k-SHv7;+|~KPA${sFIZsBg&JJINz0;&;XwXI1LDHq5~kcA#CiO z8?nxRLE|{)AybIFDMOO&loq35ZKeaQ0|dt`WOsOJI(CD!v$bY%FsolxtaLhGie7n_j>BX1*8&9pLbzJaO1c_ zdP`J$y^mY(`}}z2XB)^8e_8$_-K=CRYnTQPjP5zDpNX$B2zRUY8)&p&uuK< z8Yn)iAYl6B?zB4@w`VXVgD+21eDEn;EcATdqdWQXs;1N*K0P3>4`&OKsl6+Es601! 
z@5A!x1(=p0?RnD^^BrQ%w#SQe-@d=HgwhP=LQGCW*^DiUO~jriFbOzf{GXc&<+QqY zXWMw85F03a8&2QWIDD(;{;Zh$AZB-aOt7){KN+o}WcxWsVHip_5!>nnRYc7CRvE=x z@%GwP*WzUKCwhNrT0d|5^o>_^xs^Q_KZjwIBx49JgB-9y;9s)*4gE2#Mq`$dqRpY} z8sl&Lq|%9wrsY`S>iW`YP}AXmtsh{z7yE`7!g}fYM%9lh>mtzZg8kQlB$NOiU`; zt`0li+j_-s-GN0CtJn2!N%S6mG&C(abN>?z4K{YIoeDdZJcl~%u)ks$Ts()SYiJIa z*>Y6M9z%n-WVa|Ejw3*=+DBik_VX>P*G|l4x?X2=P%6V7F!Zx*n|5{fU=Q95=(s%pKS2oI+F6nuDh^pgKHCD zN>}Bic$VycY1V|cKd-lM3Q&*m?%lbr8Mf*gbu`3=<7VJ#zWb!hTe5c@^yo7Z$;a*f z{~NI70S>+Q>-j=EMB>hY?c~FWKo@>^-kpUK>g%e~j&#*v{Pv{x)746cy<0B>q~Kil z42eVYbRtjiCi;DcG&~%m^Jb~U$86)1;CfER?On)J*BfeJ9}A_X!2b^IASU!9DG(} z#_rsXRqPql4}L~tid$$=J;QmjP)oh~KzO@$E$08=*erw8*o68O6x%nnS1dFCD)|Ho ze<94W2Zdm%d6MW>jXl(yvX>J;kEngnXfZBb+5mV&#QudwSC`$uV~D*a z;f7Rn*0T!`yIWk08W>_hB_2Bn_|W0EDBt820v&Yr`NM^a^(-xceax~TcQBb{VzQ=6 z+SJPADG{91MxPke8jMX0AxjHGafT>Nnq_bPlH);LW2!WQX}C(GOcL#3_QBysJo-nnlGXO;LEn>9UiVoAOS;M$q5(GTVkZHK?p}WE%0xp z%YmEfiF#(##3O+Xub9lLRkpf`;$}p1F6hGdo}@ZslZ ztTdky8w6Ze;d0^5?$DqmPjs9G07PF8Bp72DOZ!(%h;AtQd9*FHQu6H!Cvp&1*TFZ@ zKo-@upP)IZf_6W94yua{<}J=@2NmnY6b=~t|9 zK6D5I1r?nb9t%`|pz)S%$jjJq=R1N9%Ekt}fPi9AcRgW~rg#1#;J>}xN6t1!^aUVI z6Q77lT%L$q@6cS1WV9_lltDbejT+{PrIMKMy&P&NUjku^X*d~I|l<~x#g zA!1tTDrXl(Fj?7hZCb$3lKx%=I(LthqTD?;a}%Ean$ zBrxbx4{7G!W?JZiKi3q__appEnP5GgO2Y6BJMXM_Zo5XJPEw9LGG^8wbPqUYq13M_ zOx%c_>65Cu5YYG!Axj+1F){o=j56JNoj@sdW*u#lv~~LQJw7XUq0A+$aX|u2>aC#p zF~)X1L4bU=fq2uc&j8}8P26l;O3qQ}I1|73P@R}$wlb_(a=4A6JmB(SLbquyMz1C> zp{@jy+-FXoe*c>6xWs1Rt&TS$c_`zh0=8>EN>96LrZ)W5UpfS)m*hn9wy+}fJ1M7# zvC&*Pv`&Ud9DkMS>OjWj(I5u5a_D)m-cWBDx*=2C9q2iw2~vj9Lp~C5s%D*uWi(LO z^{T;3C|8qnzzx5`w@f-LpXM%!UC7o2YLLL#q)~^D~QR6GE`y;yjjP*S@{Pljz}JkZC-|CfkA9gQ{@s5ra&}lZ z{@41h>9fDTx~Cxj!`~)feDS|szs97v*xilOe^Y6(Ys0}ev!*G`R=COSKQ;TOO#Nwz zf1KpZ{$6zX_Z`-i~*I zqZ}KbXtfuWz6Dy<4x*uvVUE#kIYvNKVDd>wj&;`T^%UBLd^fjv@CEwxOU=G0YniGA z@UHC!N(aNaF-3$pX+>!dYrG|DAsVg8&lpWpKTGAt!ac8c21F7c|ERfYRrq$Huae4u##EQH~o9%wee!6I>tG&KUE?k@TGa0dld*009X1QQ< zH*}m0l4=kT3Rh%H&neyobWWApH<|B(#`%-|@H;=NS1%{C+J3!4z+be-4wL#)P@u)5 z^EirJbGqoET*@9w`#VBJ>S;BsqD4(rxCTn{!<-7o$xgLyV{a7uTcyN)j>K`7Ap!wK z9*7k@L!SzN*mW%d6X6=MK1v=RV~$9esreBgY;l`vC4j>|h1U_g&5oSnTmsziSqOB1j$KdGlWm#cF4ZKbbBM%b0J(Q~hTyk{sDpId| ziiGA`Cl^KtPYbnhN&Gkk$_d~`UK0NSs0s%1yt(>8KzXn5x0A44QE$luo3?R z;5vL*vjCt~3yL>iEJnTifm+{a>nrETM;udT*~$N33_r#EI8)B(6oW2G@--z4+B8A= zPE&Z}Qy>s!y_^9ike2p}6^a%WbZ#E4-IGiDVKrIPhgp9Bh{loKv7iELT!6-xTujl( z)_7(XZM|p{e%Wt`A+rsu-YiKNsrX#Q&^77kTXbkD;yJh_Kt}M(S-Vc@s&0_cL}v*; zZJ2!kqh7bnO*R`~BK&IEMdt(F|KtqzztreVYyzcb1k0_LE&1zrn*jvUSU&3NzF$i= zPIeK&M&W3@ccx(?|2o8UzupKcB*pDG2bUi<1g_-NHlS`q$Sxu%ARG;rKvp{lQXoF& zBbZ8;-F_$Kqt4)WXfVrw#!=_8DqtBDr4+>r z(&9*K>ZjY`wqbq)wrUr;!|5V$uH1K#_nkaJ4+(7y2Wy7!3yK4e@5&IOqLBN%$bClS z)nfC7+g5~`u2N3B@jB|DgW*SZrY{Zk*ofse7AwXQ`(hV+oR#d6#hk014C>nvoH6ox zpmx2#TIv3^`+m$Fuqgh?!1$p6OC;jpx;Yu3#OJDVaI7+8Cs@ zaRltp-XMJLvBna*6Ve33f}_ETZ>t?$=!o+5#^PuN03O@nc>gKpHiCjVe<{Y3<%SEF z+c|fjV)v_q@_;U&uP^o8z3XI4%Ckq-5AIVQpQ1tiX$t~Mm+VgK-|@E5P-h>OA4%r} zF&GVnEa`e`5uR3*eC(i5lzFx7j9A;>@!9KZr(TaqJ?z#-$><9;P0f?%mdnR|a7eu8 z&m&aff4p1wjKN?HX+yOk*hXZ*+(Bo8RZ!}^OntNKnyZwwv|y3p{N*!*&3(%fh1ua1 zGZK7I1Hcn~4!Lg1t`v+fgNE}_Cx;F}P?jS3L1%OoeK1V4-$QoyU!WY|(*oH~`{6Ez zUjR%?c+%{;8=_6nxBlWyDOpVhg}lJi5EuFUX|e6bzHUH6OFsU!;D3arzGA&GMOMaojeE@tVwSw#9E|XbqKx^*O3+H zoQ3%t+=X(9+frE-TsH~?W!rg~xmQ2GD-=yEV0__g-5rM}RNKq}5L~3;d`!9LA>Da4 zipM+-SYnCYWNmKdecihc;DF>)U(TY~&sbiZkTxK@aV5Y>SU$4XZvoOG8u7aJ3@sxK zCanips+~t68uZ%5oE+f}9vH2!(3^bx`6^jmzzGF_IaY}Z)tNKN5yutQQM>{uFWUB5 zJ1J%cq_7HmN*~*&?^oXYKWv)5pP!KNREk1*Pa|+}yY^J&^8$5=Oe!rBu2<#Iw#2sU zV6&A4S@hY)-Gffs#W|xd#$~9~rm!*$vst^_1B6~FHGGre0_JUv-GOG;{n$`C(&M(r 
zm$dfbTFMK8Shnxq*+QE-s`zSmYQgEeqqWYX*0uH2{&5u&K~e>vR$=+~zzx{E8$%sL z3ixyai4u&&3Q0ZyuVPm@TMNIXZ|z}yOs&UTgS3Ng6Jf!A_us%hl@z*gYb&}MftNG5 z>X1QV?JKg2tb54;u-1-zV%Xqia=X@Qkh6;ina|TN2#AUQ^el%g0&#MRRE6L&9*pT; zKE5IbdNAwdenOMd*ZY;oU(O>zSNXo1Rx##D`e+LhV2W}F6RVz(t>&{PAosT`8__p? z6f>5DOtC0;!YJ{WFL`xKcbI*)z6+k2EIbX5|5&qQUsE7?0Ismhe*2yKHeh=!2G`yu zV!TpFmh%Yru}KVsiLNeLZoC#)R9*W^LP#@G9<a0?pEKEW!YXDC9Cv9Dk`?uz1)0 zr)k!&Pm|rS5_*1{2L`b3Q{9?9 zM2N?U%uLWbKwzZ1Y=kSjMOMa5QTK9nZoehKAXMaD$C`9Z#6|AJFOVl7OJu+jD^n%M zR|H4<5~H{DPwQ@-_%nn>sX<;-iA)p5vZ&a^do;cYk(=yUvR{vhJtY?faq^Z#LvXd4 zr24H0)d3K5P07gBR-bnhv2#9M}3j!N!V%X*SgxL zgQKH?-qm%mk)>)I`_q=Yg_tPrHS+BBA&l|d?KALRP2>)0T24|>bF9&^|H$i^w_)y{?9QXd$hNGdf{HJXU{-{ynX=DF#&{ay=x(`MfSbOIrQ{Ppd#0#2 ztA%hfF0tt9A(>#{r-^jgP41)L6G}-DsNv_^QTF=zt&X84aYe^K@f(Ixb-hcl_51W}pnOD+oVyxM`yXHN11ZhO_Rd$sx z$GjQ)wCW^C={4v!WDPJY_mv)N(+5RC=Y~Noisu(aP2JK>$m82Yq3A|AO{crFIW9;<$zCNA) zEvTu8JAH6IOSpb00ZZR$4@$h!p@7>qk2CtB4(*ca){XJ~VxqcS~mIpY{W66De^n zT(Ts}avoPMe2q#QIv2Tb-5dIZ7LI@AoEF)>CGjAr`m2}L9g##By`&xs6L@E&?DgTm z1*WK7MQf6|#TPe~02?i$`s728Q0lC=DL&u)Ho+ABM(re(Z#%SzTH&N)JGHuFhX-yL z8yl^8=>D};8k9pY#?A=Yp>#)exc6i?Z{h*~Hyd>A&v*J65Z;cZKo0!psChsO8lue( zh3!4B>(cKnX6jLF8f8D)B^1m41@|T|KFt9#WKCndqEU(j_0R06<@7_6w?}V?U#LOU zZu2tq91ua%SzRH+)W3=%p65;y;>M-eE7W-d9YO%WosE!u@f+J)7;w% z9*LOi0AT9SRRCM&C!vu6b5)?+3;(3G`1snEohjc9-!J5eZ42RyN@=TC4^ddoJA+ix z?Wrm+my~bt^#ZhExLS;wMTGJ-Jm{!lRD_uOxoKlMn9mZh5H18+ky2Xx*bVERtoE;3 z#YS^9tKDHUF)$ekRrBr~8nHcc58)ORC!%(-{Tptvs)zIc*^v_+RUtBN1Rp)J>7@qc z3%10^pG+!v8;V@t_K?^}xu{J>|F6C8jEnkM+up>e$B252M=2JP&@_r5RTKe_5k;Ct z0qGb}KoFE79oCp=q7+>eq_45lMT&Hcimpmi=>jUNAXR#2-)m+UllS@detJH<&hPxr zIjsAio&U@|_uO-r>(bl;ELsd?0Jpwy`qsFa+f4^^iz^@`<+mS>+T}9^Mi@tGwQX$o z6x<#J-E2$EQ5}0dmr%B?H`RO^k+puil_!*fmi;BAHKrX8zy3&aO8Lih`WqZ;;?n^N z6Ga{0w6B|6NI=OLHFUqqUqwDEAs+$$o34r;0>P?5gw7s4P2^M7i}*M%a>GEYJE|I= zFj+4A>u+1LTBlLxPRy^8Ji1{{kw>RTWC&)U!bbd_d%2FrM14)p7yC7fCkY^H#iv(W zeyBm8e6DPQ53?hh$C>99dhp{E$mW#+(nL`wYOX$q>G|9*Y z*)O_MbwAXfogf~ z)q{cqsqa#M>4(3~Sh#3udJl*O4ynFPUj+En{^59gjM z#RFeJ?`F>Dsw0=>`()lV3kvE7^bMfaG5|okK_4Ct<65KBe-Y6WM*l| ze7)^Z!j|5m3_#GwZprf>cjJ&I+@CT#R2I+r*R*P;$U>-=iUpvH~mjLPuc>-)HUEp#v-3{1yIMTkr*=oe^jb#h4rpV};}fy;+o` z+i?Gv+3+(-woJJ|I#<>%d&9Q|7!3WKkLS66c}dT=H^Sy$uC~di>KIuz`HghhvQ*^H zO}@~SGO+;1i4`yti+K~I-t6MqBqzqHTHqv913Og9y}Qs$AOwo$@`_{Qw^DLSDjR#1 zIfn@5I$TiH0z> zEYPy=fA}2~9i$JsbK_sYx&@>*!M@a)$nVbPjCS>Hj&xKKzQxTRmRQDY9H^EmO`Ga*Amiu|p;)~o zFT_N4Y7B^8d3gDFwGn5Yj3nr4A{j!d$fE`XAMX8)$fDxg+{!&uj&5%_#Ke)R3|JTd zZrRp<^t4wl?vYl}>VvzngpKfMYD|zt zM{9C13UaHy5uV_`UUa{a<*X)0>_dT7|C;WecrMB7Z~wF7OphmWtz()myz;X4oCh~q z!A2B=K^2dId^7tM-6q5p01c{JvCC!R5Ee5>q8Uj1NcE+;6Qe;zqYt@G)=8;0u%jL< zA-TY~{X0RS1WlPBri9V8fN%n`M26;E70c!M?I2Ddsmdl;hEaU9Hd?L|_DXG?fRp5G-_aI|tpA8s$it;!!%X+_91YD~@XMg}+i(+WP!tTTD12H@eIguev`tm-uQYla6kF9IAq@$r+iX z7r|%%gaWC(C=KoqV>dPh`XMPFb8oc}PCN3pjCPO0r{UuM>SNQ7h~$dB9tjZ(Uoo9c zC6|xid=T{;Vx?@1#O`uP;>Sg8j|{8S07!1_pGdyGxqWZ3&@O^JvdS` z?k^O~o=iulPU_Tu-R|jdS{DH*(9r2eH|}QQthVtcUJR|syWr^jatfMbw?b#&I zUoqP<9*IkBF2+R&aRD0e@kn>|zwdC{d}u_5)Sm%`GmD@(#{F5WIaYZZB3N8bPL5ov+joP0&wuz1r$ozh-o--|2F<)_ zkZI`?9r;B=6B4&oo5-E3GaXJm{~wo;VeU7ClfryGsx0&D`WJBeikgtJ+3`KQ$>P{{AT~h?n<-Kn>DIvn9$1U1 zD7)raPy*vek03qDh1=;ji=lNow=093nnp>9x(?9*0ej^h-Wp~Ipqacm4l>1~+xr*s zc4L&iNcvk(`T~Sd>msQZf1GEI+Gi_M5nk!)IlCMMzrT^25OZOFVXqYEq1VgycdhdN zR;@d;#``Vg7(AVM|8mGc&c&iY`462qqlU9ljh>os%-6@=KSSvW^Da)l@4wpesQq$E zZZ=ZcbLa+M*mqJqGxb1j=Hmts#&PeqvcQ-OzyHvXx(+Mu)w`Wi!tz>4`1W+NAO^_N) zV82XFRd3hYOA2jEMF~7GVze9kAM=kiHYY;!yan$a_v|C-V)Yh>#hV`Pv6q(6CZ~$l z0^%cOmra2oQ%2)J=HpW+UmSbY#uA@eO-QrvFG^egv>at?wcDdxP{$OeW&baH{BRa!s)Ig&Yd|JyV`ocn%iCPfs7q?wt;gbrQss&A^ 
zk?wdql_pG>2(;2x;NdF#VD(xMmW>DbxueWj0;PXw7G56Ttjuov*9 z1m6|z<#}=KNS7vPxI;Lk(_EJASaaIf-H% z@a3$tBEj{x)pB}d(U_^Q>GX=Fa^C?{t|D9rP@UVljl?}5uN1qTU1s4PQuY0PC!BgH zi0lniB{(O9ect&FI+af;dfP~LeVlz#3M)}++s>I>fLV)(p^o;$;E<%(apNz%BK1=^So1<{`7m-;AOXNBO&IRoRqM|xq_h-f&d!hYGpaw`6kY8fN9lS zG_v_vICzco_&i^amK=@tG*K8kol^C-Uj_%;o8c#Sh--u`N_Cz-$@SWy!6GpI|IIPa zQ-a#7nO%xP9M9jXW@|fW0xsl+W}JnuBdmvPg!=fq?bQAL?VP4s$(*3gQ30-=3w9su z)ie-TiDlFee&jU@0Z^!SH7zagtD5~Nm_@w~it?cpNDe+&;ggtIBsCZ1&`9_vz2pAC z8!9GS0dsy?A^nG_K&SrBr!2oT-SGeo;k?G*xnaD-P0)?e)O-=^8Se~P=9rB zcv?V;=ZU^icfnq1mNiathVdP>?Av7FEA_cib=O>+IUKdq+LgEgi6fvxFH#C=9jVbEGJm)3H zynAZidQ(?Uzxm@S?#sP2i{Y0Zm7u_Wub{|R-nPW{UM#^fDNN*V#QoN5O(b8q0De$9 zx7S{$;w{uF$7-^MJGO-QwAiL}K~JD;AGisa7uSFQj#n{!P!k9hO9u4`qk^ef%#Crs z4eI$$g#ohMiA7$yl>=`VwStgqCV$7%wC_Yr_^o0+<8>+PzihZyR~8F>g1!~aI!4@jN+7Yoe}f%yx$N^e-223*ZVnN=#J}bd zLP_T$A5YQDNVqFdxK%V>O?c46USG21dnwuZ$zA&I^>yDH=&UM~L9p8?xvs^;a*jvF z$`Qkb{(zH4*E9*`KNH=q;Uh)04vS%40N}#B_ATlQP&Rh-x5YK3DLK*!(;lep_nepB z00bT?mRfCQ>LWqgXZD)6KJaiQBYiP-Ay7|(Nr~lJ@sCdcq9oM)0btp zf8Op^rkrAr^b^gAJKU@^ed1phiaz+{0+%(}DoEQf-MA!de-FeRdpMr}*EH-MeZ!;XK$li_`M{ z^ExVIMO?vqnmSbJZ`A<=#o5yOhb_%4kaqA@1Mnu?2FgJ88he`hGEEAF&lsc^+q z!o9bXb{FNI%+1V|XgMzkUnnIISrEFjxyJ#&oB?p!h8K{x?ZR2{$fld-8k$wWz=*t8 zc}YzX%%$9k7DM<_oS60ltjNqc!$g{wrt;j}8gmlYI+#oq>a_F*2<5r@WMp`iiI1$; zHJ+*vuI@Bt2_=g)YcQ|Z`R8+QDpDP=g;>TV@_7?IH4euJG-JCE5^SxrvrxX1(l(IbYWcJO!?k7=XRR&-i|n!dmthvp zyYK`_cpdfP%q(XmhuMbes0=NQAlFC+TOy=x^;me=m_G>6llkzoKaae6k(W98US@xXsQS2n#pzk>xB#lT71+NGEG z0#>avQI3p|QR$;nuQ1li=ja=zDs&r=?rN?*{B3@cF7w76XVd;W30d)y^+@48DCZO+ zvIt;&o3pdS%5eiHuZ|i%T)~8h-Er~*tjxvKXN|6zmyZt4WgRmoo))bM{z#AVrf}5% z0v}09L$pTW5biVzhxAZyoN_k!v^XK&1{PBFZ<#TV#9bHIrsc0Ej#p%OfCM>sMHG;g z_(uT}I_o|#x#wE>VnzIHCL`lJ?sXiqK<}US7Gq(#6F}NRSw%%a#;$R%+MY7foxaWKfp_IrObSrWf|r{#a$a|S z&_P1C^0u1%UibiT&i?=q)jWQNFT*3Et^3l7p?G?jxKio;7a~^5X8v8bJHBB*xe}HQ z6wSj^ecH>VmdIqUS4<5YZc?rHA80VBDeJ0t(oXbxMGD&qitucuhj33x>=0Zr7$>{4 z5_Xn4Q#IOMMs9{GmD?6&X`e-wT_i&v?%K|8I`JkKO|qL+fBvy^Q!S8Y<-G4j+mAGb z-?nw~(^>p@Wd=3I0308x zVXWLQoaUpSF;p`nD>h>JK5ox;&|3^r>>w+?hzBa0#VW#I?cg;vmGa$Su&Ciw(QR8; z@YS}ZE06CVa4rtWco%WMrE%f-xEU`le!~?{QLmwv7Sy|&-Z2ZR_(J8uA2R$S46euS zYH{4FD5v8lit(R(I!ZAdXdk4s@rzBChK8LAuP>}IbR$zmzZo9REAj!fkQ!%wN#GZ1S6>)Ci&0`me7<8`6zCDIj++@AX}BQa zFrafL6j5axt~sFxzj`8Ink&bSs4zYT)#s1M=Xd^9-kP9jub`-N$SvyCeM$M;W1?T? z+U)#tUXIpZ%HMxcpe6Nd%&&eI|CuZMs%`u0+KWc#s>SU?-msB3o`IaY@%*| zYuaB*{;Vt@M)vE2lUx~2Op7!6*24A#;^YTCIxJW!QqXX8m; z*I;Pp6#r&{%(l_ z+&#jLcW?6x3K;464ILcEUfoR;IFMOGlHjzSHe^Vt^%ijN%_Sw2l@0)kjgXsxc<4%% zSZ^@Ca6J*e6cU)CQ0}*ue*gV`<@iDi#UE~B*kk*pi4$e75P1ORu-wj zXI}mY?ll4826e#c^9O{qcIPpL=+4+aqfP&2Gfj5 z0I!(R+1=M4x(42%tLZW|4!nbz-H>W_%qpXI4;eS}2-yt|y9~EnbbfVda>$xWY!xhn zw%EM_$*ME*6KAdPbR@6+n&v?}Y%+vH!RH{%C{p2Nv+?d^zuh3gzay(=h0<&N_8JAdv z`rf8Y2O{mEdE*^ddsbeeYRYE+rZ?RjD=G^NIg&rey#Dc-A_rJc@FS3Su(`Ycjgy=D^LG}U5Lye0F^&)mmgj+((u={i$%GF-WqB7|7c+8g}? 
z%>}q6hf_R2hR|FXq_AaYCdjmDNw}iRu7*NwEQc87^VDyx+(zo!I_S$PW#CA3FJB{2 zvR%o(es1F-oJ(tn+%ipwNvX9Luwf+Y>TnhjSEWR_i%@n9VJB^K=dWIpe!xh;>vvK7 zv*_|XPn7tPF|rpExum#pEQUitrjE2!gtpYP`6O;b)VZe+1iRgsO zl{FCsotD+LC+KUGd`ur5w0k(4ZkHyKH>{-#kQj}@%iNM#M5m?SNU=^>XGS$To~YY1 z5W3UAe+QXt8+l?%Rl0%X1k4ShA_NhwBp#I-&aY)5Ie=} zCxlq4!lY5ZNn|;_x|4HVpAVo%#$>Z&uw1Ws^vzXUJE93}(&?{cl!m3bSiQ1OAo3%0 z-MvEpeZ?n5`+K|&CoMo-w8-i{f$KK&bAS2slXd!ieW5Yol*ec+_%@si0hn;Fg0YF2pj`E~x5AO4MLG9+r0 zDJ-HW4ZIRN-V|B9kB1Lu*}b@K_t&q266=yAwc#h+bBc5two$rLtG|FhCZh?iB5U#v zCT;+~dvkh5o~+8w3(K57Lgh(lDOo*lQIH)y(ug2O1%_qy*Je%otQM3rNeit^`-UsCZ=>m-@d^kzwhDGE$Ya(V02U?pZ|2r;k zn|pRH+{nfrE%{j891q9xa{!tNR{HYz@7&B0sk1e7rNtsN;b{ecWgQO$(vN}9=>b6~| zsvGq5aSBA&^__&aB&05x4%fyT?ibWA|NVsG8nw2)oM+p~+cDqJirry$b~ZP~Rlj*C z^=e@uaKf}&HvxV5#r2L(KG6&Nm|k>TY>rw~o;BGyHXHjCzJ<2>5|kIO+;kCmC2QG- z#NU9FmlXGt%$s+S=zP{tfE(#uc|b{!;=9x%ba>R9p>^p+KJUL8`O<*%t-6a802pUD#H!%b<9JWA8euswrxR z`~&1_08AVyIYf9bfYXeCmIC%&u&TV>x}k+YzG=An-l!A;oyrbv?yi>$2 z8(rmq$0ncI&`9<@K^4S7fad&SgM!-7xE;@k&f%fgg}ZUspTis?VXI}jyW`;m`e9oc zsvOxBszFetNlP!gw1%SwA2vGSr~7=pxiy%(H(?j46#tTuZeM+vv$^>^pu_}f(k7

27nVjmP-s*WnefuXoyGAWgP1XWV@@b@Vl<;X4CzcLQABu$ zf1?%~J`!*Pl09_K3%M=hLo<6wQb$&;|E*oq0?zNL&o{t?SEL(h&hz>x=Y|W2PCb90 z>A_KSdsxg7XV<@ea`wiQB&fX*`AU9&B2^WJH_WPX!+khAL z={eDs$k>s3yJ>=U2{=kB;|mVIdtx{6{%UbC*%hhv>EGpcBnk{fE!Lo%PI+UjUK716 z@yq$yClItx7}ImBJ3zWQUT)s}Gi*`P7F>>4QS_bKL@TGgnZ8aEJ?ZACqPDjx%vf&H zbzrGFSm&JU&ixIW$9DH{E7j)!xiQ-pgJ!TVde=j=gl1)#DKkbto=*Bs0+?@kgFmT~ z#lsv)T&hcm9m(Y~ZK+~Rp=j=>&I1weOyqap?1X8?N|tNi?9f6<6kCtu%SZfjIr(zvmPQm5lDzC29e=b6TIw(r-$Z->bv zSL<&W?4_QbQlKozI5Y#r;&5c{iWmzI!{=T{i5xP|y!jqe;XctF2`3u>95Lfent>7{_BHmjlvhzQ~5}k;_C*o$+ykDtUh^XiNsmIFt~F z%neX7jSUyJ0Eg3F;dU|o)pBZEErY)sIYEeNlY>yK9>f{%%aCV0vj{_aQIXE0&KVtv zS2g~m&!y7QE_=}wXW?Z@RuIP_#7z6n(F6U!*1vP;)68acM?4(ht*-AZ!Sun>rtK9GNih2lkqOHa{e zE;o-+f^%JeeMS(kcO{+6E+f&rC3z9Qaa*;Hd)q^6?~W!1?t z1be<^==`{KpPdZ>U6y}|u%5=0Be5;eL=G!1lN4?WUy)k{ZwbRegBadgtg9&UE>MH| z>8O!nu*YQ0E|4_C$+4N=_4F8$l$kVk>siTpFiqdrj}~!#6z>?ayMKzSYrPAZO!&EU znK)4Na15oF$dIc78nb3Y@9*5Va#Fg@=Ux_^5yp@Hygp(sY!bWo&kjF6isb=VtNNJT z2f$~h2d||z@0zMpQ&%Wn$J&}HPy8snx1gO7q@!P*U}77+sD2OLIK z*s^Fs7h!~aT)=r4MEq%R5)qT__0pjsk~rP;Ov%>++Rnx#4yh$$6*BJc1+Xy8n#y8H zpQpm%;GyY`qpNeun7=xG)rYIQ&w`hER%u#Lp~qrI`W5Yv#OaBEw^X6yH`&U;G4U%V zYKR+)s1G-?Kbb4DYpp(j^3;OxmK;~6*qlB zwL2a*t})((K1Hl>832%!Ax{a;!>eUpOBt0%4`+20zlc$#=P9<=Dw8A^lUJiMob$^= zfWO7~hB(9+`3}!65p>7uB$C1I(|#Ekms%`kXmxISf2m`^>{7mQf}tVg>E_0L$ocaH%gmgd9n4Z1 zi0aB79wup4S3-9M_KvX90<$n#nGs0Da7xvwNfDq1;c;8L3~mVL%zRdP50$478eb7n zLW%*f5J#G0fEU@{DjvE1OOVug9~&zDLXq%a;iq8^8@U=Ti;ad4==qk*WZP=B(nQud zO&!(P-Q?~f&hzdcd;#JY_LzsK6oT@AW`{YP%I`qhNOw-CZfNfJABLoVL64hhK~$urTkt4usM*`p@I2u%QAs#qFGsC}^l^uNl7dL%?<+DR%9VD^Kk!q)H4z7kI zqlLWZL_L8G1AC{(Qh<(ZTX+k)L6w?FQwnZxND1Bl)P9t}AP7-G*sfCh^LKn?x( z7D*FCpMNptm=szysrY}mM($@2el&-L5_zoIH>96d=#+wf3v$K>!PMH;*=vYYGl2Fa zebGw%d;ek_Z$YTXby6Fov@>?Z!losxZRL6q1sg<_i|C?Jxk^uW*aK}UCjX^;!O)0_ z>bg{`8D@v*`m$jpq0IEV^1TU2=Jz*@T&2b z>VD(n4djyU)O;>D#)?=LOcT1;LQYR3q(A}0dsi$eKEAuZavO&5t)r7J5}tr1j}&6M z-{FBRfc9CI5dTbb%?A-v+QUkUt03f;Wkxa-L74f(Vj__!|e77Fk(K)}V162t?i$nUrSWNL6%h_5<8{63c(+vkCE9@(@ zl8?%c#%>>-l9n&q5lU%YO8BUj(a$puR*wPR7DqC15N`*-Bw)`7z^i(OK-Pr}?*Xbym=Fn1BN%}oNb+; ztV8KHsP2{pQVLpoVx|KZn)rmt}VA4yKEJdzB%q2w%mn}T2zY$T-_kYwZvo?32eC~)3nEoYb5fZ?DfRlg?9 z@VbGlU|WPZ{b=$$$Uz}}Q~gN$)ubCZ{SfM+$IELYkK@KyB0xD098;LkW8s=Q@xF;3 zqKy2*ShWyWy_ytV?NGXv$wQ8S4efctWmeLis&x#r0(TL$$soI87$2 z#HElh#i>aFEu%O1d%4{3aAfrmH>}|FhzLE;(a>cDge8`qvAs^I{Me#Y;JQv>If@E4 zg{iYh$~@9Y7lF@Di0Lzsc>66%&97)vT}bSM4F|0h?6-T~dHZa;jybWRU`0>}Dw}hJ ztHJ-6=PG}B*sugd4h*IqqFSA>{E@)bVf^agQ!3?HR@?TdeGC3$FPk*ytnCd-HOH=} z%hdJ7tXb9Mp5WK3Pl`|g71Sd$FN+2p z*35<+^p!SJT~lsmykMD3naaMuEi$3o+1Vt$PaOd_ROn9DCrF|Cp4c~?aAhKo13+~* zVp%X!1oqA4z2fKfnsUTems`X>v$#auBC|2wX;co`Y&k;t1O4j@B>l?L$_G!7-GT5I zqf3LF*UULe$Ge6Xw*p)}9SbRqeq=#je z*h{UH(5=P4Yqb(*3d7;cA>Ry#J)VKBswn7`Ap~An29Pg@+b76c{wtGY+em~ZpE`pa z8xkAgpCR6kL&I0?-4tBXtQRh^pYGUw`%b>ccVsdQ$uY*zzn&s%n$F%7X&HeOHD#O0 zhFTVCZpP{>b_su3cHPQwp}M?jZpNm$D|W-CA?*i=N0Lq9iO1gJr&{B-di$j1BE@|n+@cAVYoH>Ln(yRSgor1sGi9wak67Wk zPvO|lg0Y%y$JIVBM@&4lN8eXk3h~XVC03KShs&R5;dlA5bcWeZb8{p%#4P~S>8d1` z1CvR`iKW3|>ObZE+hO03uR>5D95 zIX90rif_)>jU;AH8r4ti-!Ldpc6qW0h%~QHljJQG4QFw8X#UVR_d<28Nz=`{&ZfBY z6OS;N3Y+}$9y_R&ubJ*HQpyZO)xx!oy*oQ3SfS^u+m287SDeoe9CO;N zZF2uX=$`7d&W`%&*3doMCML+*pQ5DL4Bq9mv`v^V&A`z&bgF_jk98Iwmyl4(*9m6< z()#jSTzL=_rRvD4*=nw|uz!&tNX0kln32`Bk?@v!#`~=wn;yYQh!V-BCrCTGHvdG^ z1}EbT@4jKh0AImmVZ;+)_`>gOgwYl{NZepjBc(voO{ z9Xm33(~YOk%w2h5aOh$xL`!xK-ear3$hy{T0GHo_`&E%koA_9 zjB5@*{0IC2s8(&d?p^Xlb<~#p@izXhAmb3@qwhV;$^wC|QQLzhIYEC>>lVkwVPSxiU)>5AhmbQb*74R^fiCaw zN{K*Hxu<(|HuA>k*+(9JA%VB%tBV?ED#Io5APTQU#a@E zcLq+|fdm=DEy^CnVTWGSL_8IikQ$;$k9>UthkWq4+Na;sYNdk3{AJUo?68${pH>7} 
zHBFC>wl)S8cy^`*mMd25+xAGfVc5xPm!F@);g`C+SM62RWfzVdsZN)_73gYdSv}Cx z)6w>+Lq5VRIY;vIeOD-6cZSbT&?=)}dCvTQ{{IRIKu*zK(n-|G35Xo6S!p z{_zDy-S2$UZ|1WSZE_0#qp46x>AycnUrg-wzZ)}zxNZN@oFDwHH_68~668AdkG7nT zOZ4fJ&qE+n@p%?L&%(@>_$&*bW#O|d%*?=NS(urL&wAmrEc|~d3u|9_sG#gqkK_W( zOaSAv9Z9C*^DKOxg_$k!Sr$Ia!e?2SnSsx;Ff$YXn|fjQ#Bt|Y-wKsz$$vX+z1H^L zoXUbC zf&A}hdH5_3pXGtU_-seVy9cc~I*qRclKW|K~6mC3DrYPB)6QTDZ^t zzdj$$pX%o4rO$j^67;YC5pxd)6vIE; + + + + + """ + + email = flask.request.form["email"] + if flask.request.form["password"] == users[email]["password"]: + user = User() + user.id = email + flask_login.login_user(user) + return flask.redirect(flask.url_for("serve_sphinx_docs")) + + return "Bad login" + + +@app.route("/logout") +def logout(): + flask_login.logout_user() + return "Logged out" + + +@login_manager.unauthorized_handler +def unauthorized_handler(): + return flask.redirect(flask.url_for("login")) + + +@app.route("/") +@app.route("/") +@flask_login.login_required +def serve_sphinx_docs(path="index.html"): + return app.send_static_file(path) + + +if __name__ == "__main__": + app.run(debug=True) diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 000000000..d6b127d18 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,119 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import sys + +sys.path.insert(0, os.path.abspath("..")) + + +# -- Project information ----------------------------------------------------- + +project = "gt4sd" +copyright = "GT4SD team 2022" +author = "GT4SD team" + +# -- Generate API (auto) documentation ------------------------------------------------ + + +def autodoc_skip_member_handler(app, what, name, obj, skip, options): + """Skipping tests.""" + return name.startswith("test_") + + +def run_apidoc(app): + """Generage API documentation""" + import better_apidoc + + better_apidoc.APP = app + better_apidoc.main( + [ + "better-apidoc", + "-t", + os.path.join(".", "_templates"), + "--force", + "--no-toc", + "--separate", + "-o", + os.path.join(".", "api"), + os.path.join("..", "src", "gt4sd"), + os.path.join("..", "src", "gt4sd", "*test_*"), + ] + ) + + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.todo", + "sphinx.ext.coverage", + "sphinx.ext.viewcode", + "sphinx.ext.githubpages", + "sphinx.ext.napoleon", + "sphinx_autodoc_typehints", + "sphinx_rtd_theme", + "myst_parser", +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. 
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
+
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+html_theme = "sphinx_rtd_theme"
+html_logo = "_static/gt4sd_logo.png"
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ["_static"]
+
+
+# -- Extension configuration -------------------------------------------------
+add_module_names = False
+
+
+napoleon_google_docstring = True
+napoleon_include_init_with_doc = True
+
+coverage_ignore_modules = []
+coverage_ignore_functions = []
+coverage_ignore_classes = []
+
+coverage_show_missing_items = True
+
+# The name of the Pygments (syntax highlighting) style to use.
+pygments_style = "sphinx"
+
+# -- Options for todo extension ----------------------------------------------
+
+# If true, `todo` and `todoList` produce output, else they produce nothing.
+todo_include_todos = True
+
+
+def setup(app):
+    app.connect("autodoc-skip-member", autodoc_skip_member_handler)
+    app.connect("builder-inited", run_apidoc)
diff --git a/docs/index.md b/docs/index.md
new file mode 100644
index 000000000..0f10887ef
--- /dev/null
+++ b/docs/index.md
@@ -0,0 +1,21 @@
+# Generative Toolkit for Scientific Discovery
+
+## Site contents
+
+```{toctree}
+---
+maxdepth: 2
+---
+Examples on how to use the GT4SD algorithms <source/gt4sd_inference_usage_md>
+Examples on how to add an algorithm to GT4SD <source/gt4sd_algorithm_addition_md>
+```
+
+## Python API
+
+```{toctree}
+---
+maxdepth: 1
+---
+API of the gt4sd package <api/gt4sd>
+```
diff --git a/docs/make.bat b/docs/make.bat
new file mode 100644
index 000000000..2119f5109
--- /dev/null
+++ b/docs/make.bat
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+	set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=.
+set BUILDDIR=_build
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+	echo.
+	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+	echo.installed, then set the SPHINXBUILD environment variable to point
+	echo.to the full path of the 'sphinx-build' executable. Alternatively you
+	echo.may add the Sphinx directory to PATH.
+	echo.
+	echo.If you don't have Sphinx installed, grab it from
+	echo.http://sphinx-doc.org/
+	exit /b 1
+)
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
diff --git a/docs/source/gt4sd_algorithm_addition_md.md b/docs/source/gt4sd_algorithm_addition_md.md
new file mode 100644
index 000000000..11dc9c853
--- /dev/null
+++ b/docs/source/gt4sd_algorithm_addition_md.md
@@ -0,0 +1,358 @@
+
+# Adding a new algorithm
+
+## Getting started
+
+The general structure of a single conditional generation algorithm in `gt4sd-core` is shown here:
+
+```{code-block} sh
+gt4sd-core
+ |gt4sd
+ | |algorithms
+ | | |conditional_generation
+ | | | |__init__.py
+ | | | |[My_Algorithm]
+ | | | | |__init__.py
+ | | | | |core.py
+ | | | | |implementation.py
+```
+
+At the time of writing these are the only files you will need to be aware of to add your own custom algorithm to `gt4sd-core`.
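+
+Of these, the `__init__.py` files are the simplest: they typically just re-export the public classes defined in `core.py`. A minimal sketch for the algorithm-level `__init__.py` is shown below; it is an assumption, using the class names of the template algorithm introduced in the rest of this guide, so adapt the exports to whatever your own `core.py` defines.
+
+```{code-block} python
+# re-export the algorithm and its configuration so they can be imported
+# directly from the algorithm package
+from .core import Template, TemplateGenerator
+
+__all__ = ["Template", "TemplateGenerator"]
+```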
+Here we will talk through the implementation of a template algorithm we have called {py:class}`Template`. This algorithm takes a string input and returns a list with the single item `Hello` + input, i.e. input=`World` outputs the list `[Hello World]`.
+
+Since `Template` is a conditional generation algorithm, we have created the `My_Algorithm` folder (`template`) in the `conditional_generation` folder, and inside added the three files `__init__.py`, `core.py`, and `implementation.py`.
+
+## Implementation
+
+Starting with the file `implementation.py` we have the following code:
+
+```{code-block} python
+import random
+from typing import List
+
+
+class Generator:
+    """Basic Generator for the template algorithm."""
+
+    def __init__(
+        self,
+        resources_path: str,
+        temperature: int
+    ):
+        """Initialize the Generator.
+
+        Args:
+            resources_path: directory where to find models and parameters.
+            temperature: temperature (in celsius) used in the generated message.
+
+        """
+
+        self.resources_path = resources_path
+        self.temperature = temperature
+
+    def hello_name(
+        self,
+        name: str,
+    ) -> List[str]:
+        """Generate a salutation for the given name.
+
+        Args:
+            name: a string.
+
+        Returns:
+            a list containing salutation and temperature converted to fahrenheit.
+        """
+        return [
+            f"Hello {str(name)} {random.randint(1, int(1e6))} times and, fun fact, {str(self.temperature)} celsius equals to {(self.temperature * (9/5) + 32)} fahrenheit."
+        ]
+```
+
+Here we have created a class called {py:class}`Generator` with two functions:
+
+```{code-block} python
+__init__(self, resources_path: str, temperature: int)
+```
+
+which is used to initialise the generator, setting any additional parameters (in this case `temperature`) and the directory where the model is located, and
+
+```{code-block} python
+hello_name(self, name: str) -> List[str]
+```
+
+which is the actual implementation of the algorithm. For this guide our algorithm takes in a string `name` and a `temperature` and outputs, in a list, a single string greeting `name` a random number of times and reporting the temperature in fahrenheit.
+
+For your specific algorithm this second function will be your own code.
+
+## Core
+
+Now we will look into the file `core.py`:
+
+```{code-block} python
+import logging
+from dataclasses import field
+from typing import ClassVar, Optional, TypeVar, Callable, Iterable, Any, Dict
+
+from ...core import AlgorithmConfiguration, GeneratorAlgorithm  # type: ignore
+from ...registry import ApplicationsRegistry  # type: ignore
+from .implementation import Generator  # type: ignore
+
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler())
+
+T = TypeVar("T")
+S = TypeVar("S")
+Targeted = Callable[[T], Iterable[Any]]
+
+
+class Template(GeneratorAlgorithm[S, T]):
+    """Template Algorithm."""
+
+    def __init__(
+        self, configuration: AlgorithmConfiguration[S, T], target: Optional[T] = None
+    ):
+        """Template Generation
+
+        Args:
+            configuration: domain and application
+                specification, defining types and validations.
+            target: optional; in this instance it will be converted to a string.
+
+        Example:
+            An example for using this template::
+
+                target = 'World'
+                configuration = TemplateGenerator()
+                algorithm = Template(configuration=configuration, target=target)
+                items = list(algorithm.sample(1))
+                print(items)
+        """
+
+        configuration = self.validate_configuration(configuration)
+        # TODO there might also be a validation/check on the target input
+
+        super().__init__(
+            configuration=configuration,
+            target=target,  # type:ignore
+        )
+
+    def get_generator(
+        self,
+        configuration: AlgorithmConfiguration[S, T],
+        target: Optional[T],
+    ) -> Targeted[T]:
+        """Get the `hello_name` function from the generator.
+
+        Args:
+            configuration: helps to set up the application.
+            target: context or condition for the generation. Just an optional string here.
+
+        Returns:
+            callable generating a list of 1 item containing salutation and temperature converted to fahrenheit.
+        """
+        logger.info("ensure artifacts for the application are present.")
+        self.local_artifacts = configuration.ensure_artifacts()
+        implementation: Generator = configuration.get_conditional_generator(  # type: ignore
+            self.local_artifacts
+        )
+        return implementation.hello_name  # type:ignore
+
+    def validate_configuration(
+        self, configuration: AlgorithmConfiguration
+    ) -> AlgorithmConfiguration:
+        # TODO raise InvalidAlgorithmConfiguration
+        assert isinstance(configuration, AlgorithmConfiguration)
+        return configuration
+
+
+@ApplicationsRegistry.register_algorithm_application(Template)
+class TemplateGenerator(AlgorithmConfiguration[str, str]):
+    """Configuration for specific generator."""
+
+    algorithm_type: ClassVar[str] = "conditional_generation"
+    domain: ClassVar[str] = "materials"
+    algorithm_version: str = "v0"
+
+    temperature: int = field(
+        default=36,
+        metadata=dict(
+            description="Temperature parameter (in celsius)"
+        ),
+    )
+
+    def get_target_description(self) -> Dict[str, str]:
+        """Get description of the target for generation.
+
+        Returns:
+            target description.
+        """
+        return {
+            "title": "Target name",
+            "description": "A simple string to define the name in the output [Hello name].",
+            "type": "string",
+        }
+
+    def get_conditional_generator(self, resources_path: str) -> Generator:
+        return Generator(
+            resources_path=resources_path,
+            temperature=self.temperature
+        )
+
+```
+
+`core.py` will contain at least two classes. The first is named after your algorithm, in our example {py:class}`Template`, and extends `GeneratorAlgorithm`. The second is an `AlgorithmConfiguration`, in this case called {py:class}`TemplateGenerator`, which is used to configure your algorithm.
+
+### Template
+
+Your algorithm, {py:class}`Template` for us, needs to contain at least two functions.
+
+```{code-block} python
+__init__(self, configuration, target)
+```
+
+This is used to initialise the algorithm by passing in the algorithm configuration and an optional `target`. The configuration parameter is the object created from the `TemplateGenerator` class, and the `target` parameter in this case will be the string we pass through to our algorithm.
+
+```{code-block} python
+get_generator(self, configuration, target)
+```
+
+This function is required to get the implementation from the generator configuration. It then returns the function in the implementation which corresponds to your algorithm. In our case this is {py:meth}`implementation.hello_name`.
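+
+As a quick sanity check, the implementation class can be exercised on its own. A minimal sketch (the resources path is illustrative; this toy generator does not actually read artifacts from it):
+
+```{code-block} python
+from gt4sd.algorithms.conditional_generation.template.implementation import Generator
+
+generator = Generator(resources_path="/tmp/template", temperature=36)
+# hello_name is exactly the callable that get_generator hands to the sampling loop
+print(generator.hello_name("World"))
+# -> a one-element list greeting 'World' and reporting 36 celsius in fahrenheit
+```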
+
+```{code-block} python
+validate_configuration(self, configuration)
+```
+
+This is an optional helper function to validate that a valid configuration is provided. A similar validation method could be created to check that a user has provided a valid input (the `target` in our case).
+
+### TemplateGenerator
+
+Finally, you will need to create a specific configuration for your algorithm, in our case called {py:class}`TemplateGenerator`. Note that in our implementation we have tagged this class with `@ApplicationsRegistry.register_algorithm_application(Template)`. This decorator is needed to add the algorithm to the `ApplicationsRegistry`; you should add a similar decorator to your implementation of `AlgorithmConfiguration`, replacing the `Template` name in the decorator with the name of your algorithm.
+
+In this class there are three required strings `algorithm_type`, `domain`, and `algorithm_version`, which are all self-explanatory:
+
+- `algorithm_type` is the type of algorithm you are implementing, e.g., `generation`.
+- `domain` is the domain your algorithm is applied to, e.g., `materials`.
+- `algorithm_version` is the version of the algorithm you are on, e.g., `v0`.
+
+These strings set the location of the model's resource cache.
+Make sure you create the appropriate path in the S3 storage used (default bucket name `algorithms`, layout `algorithms/{algorithm_type}/{algorithm_name}/{algorithm_application}/{algorithm_version}`) where your artifacts will be uploaded: `algorithms/conditional_generation/Template/TemplateGenerator/v0`.
+
+There are two required functions for our configuration:
+
+The first function needed is
+
+```{code-block} python
+get_target_description(self) -> Dict[str, str]
+```
+
+which returns a dictionary defining the type of `target` (for our algorithm this is a string) and both a title and a description of what that `target` represents. This method is needed to populate documentation for the end user.
+
+The final function needed is
+
+```{code-block} python
+get_conditional_generator(self, resources_path: str) -> Generator
+```
+
+which is used to return the Generator from the resource path.
+
+Note that if we wish to implement specific configurations for this algorithm, this can also be done by creating additional `AlgorithmConfiguration`s in `core.py` and adding each parameter via a `field` object, e.g.:
+
+```{code-block} python
+    algorithm_type: ClassVar[str] = 'conditional_generation'
+    domain: ClassVar[str] = 'materials'
+    algorithm_version: str = 'v0'
+
+    batch_size: int = field(
+        default=32,
+        metadata=dict(description="Batch size used for the generative model sampling."),
+    )
+    temperature: float = field(
+        default=1.4,
+        metadata=dict(
+            description="Temperature parameter for the softmax sampling in decoding."
+        ),
+    )
+    generated_length: int = field(
+        default=100,
+        metadata=dict(
+            description="Maximum length in tokens of the generated molecules (relates to the SMILES length)."
+        ),
+    )
+
+    def get_target_description(self) -> Dict[str, str]:
+        """Get description of the target for generation.
+
+        Returns:
+            target description.
+        """
+        return {
+            "title": "Gene expression profile",
+            "description": "A gene expression profile to generate effective molecules against.",
+            "type": "list",
+        }
+
+    def get_conditional_generator(
+        self, resources_path: str
+    ) -> ProteinSequenceConditionalGenerator:
+        """Instantiate the actual generator implementation.
+
+        Args:
+            resources_path: local path to model files.
+
+        Returns:
+            instance with :meth:`generate_batch` method for targeted generation.
+        """
+        return ProteinSequenceConditionalGenerator(
+            resources_path=resources_path,
+            temperature=self.temperature,
+            generated_length=self.generated_length,
+            samples_per_protein=self.batch_size,
+        )
+```
+
+`field` is used to set a default value and a description for each parameter; the description is used to populate the documentation returned to the end user, similar to `get_target_description`. Algorithm configuration parameters can be validated by adding the implementation of a `__post_init__` method as described [here](https://docs.python.org/3/library/dataclasses.html#post-init-processing).
+
+## Final steps
+
+Finally, to complete our implementation, we need to import all the algorithms and configurations in the `__init__.py` file we created, like so:
+
+```{code-block} python
+from .core import (
+    Template,
+    TemplateGenerator,
+)
+
+__all__ = [
+    'Template',
+    'TemplateGenerator',
+]
+```
+
+and, to automatically add the algorithm to the registry without any manual imports, we have to import the generator class (in our case `TemplateGenerator`) in the outermost `__init__.py` of the `algorithms` subdirectory:
+
+```{code-block} python
+from .template.core import TemplateGenerator
+```
+
+## Using a custom algorithm
+
+Now that the new algorithm is implemented, we can use it the same way as shown before.
+
+### Explicitly
+
+```{code-block} python
+from gt4sd.algorithms.conditional_generation.template import (
+    TemplateGenerator, Template
+)
+target = 'World'
+configuration = TemplateGenerator()
+algorithm = Template(configuration=configuration, target=target)
+items = list(algorithm.sample(1))
+print(items)
+```
+
+### Registry
+
+```{code-block} python
+from gt4sd.algorithms.registry import ApplicationsRegistry
+target = 'World'
+algorithm = ApplicationsRegistry.get_application_instance(
+    target=target,
+    algorithm_type='conditional_generation',
+    domain='materials',
+    algorithm_name='Template',
+    algorithm_application='TemplateGenerator',
+)
+items = list(algorithm.sample(1))
+print(items)
+```
diff --git a/docs/source/gt4sd_inference_usage_md.md b/docs/source/gt4sd_inference_usage_md.md
new file mode 100644
index 000000000..c0f6651f3
--- /dev/null
+++ b/docs/source/gt4sd_inference_usage_md.md
@@ -0,0 +1,97 @@
+# GT4SD inference examples
+
+```{note}
+You can {download}`download the source file for this page <./gt4sd_inference_usage_md.md>`
+```
+
+```{contents}
+:depth: 2
+```
+
+## Overview
+
+This notebook shows the basic usage of the GT4SD algorithms:
+
+- running an algorithm by explicitly calling the implementation
+- using the {py:class}`ApplicationsRegistry` to instantiate and call the algorithms
+
+### A note on the setup
+
+In the following we assume that the toolkit has been set up using the provided `conda.yml` as follows:
+
+```{code-block} sh
+# create and activate environment
+conda env create -f conda.yml
+conda activate gt4sd
+# install the toolkit
+pip install .
+```
+
+## Running algorithms explicitly
+
+To run algorithms explicitly we only need to instantiate a {py:class}`GeneratorAlgorithm`
+and the companion {py:class}`AlgorithmConfiguration`.
+Then, based on the actual algorithm type, we might need to pass a `target` for generation.
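+
+To discover which algorithms and applications are available (and the exact strings used to address them below), the registry can be inspected. A minimal sketch, assuming the `list_available` helper exposed by `ApplicationsRegistry` in `registry.py`:
+
+```{code-block} python
+from gt4sd.algorithms.registry import ApplicationsRegistry
+
+# each entry is a record with algorithm_type, domain, algorithm_name,
+# algorithm_application and algorithm_version
+for entry in ApplicationsRegistry.list_available():
+    print(entry)
+```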
+
+Next we see an example of {py:class}`PaccMannRL`,
+a `conditional_generation` algorithm:
+
+```{code-block} python
+from gt4sd.algorithms.conditional_generation.paccmann_rl.core import (
+    PaccMannRLProteinBasedGenerator, PaccMannRL
+)
+target = 'MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTT'
+configuration = PaccMannRLProteinBasedGenerator()
+algorithm = PaccMannRL(configuration=configuration, target=target)
+items = list(algorithm.sample(10))
+print(items)
+```
+
+For vanilla generation algorithms (`generation`), like {py:class}`PolymerBlocks`,
+the usage is analogous, but no `target` is required:
+
+```{code-block} python
+from gt4sd.algorithms.generation.polymer_blocks.core import (
+    PolymerBlocksGenerator, PolymerBlocks
+)
+configuration = PolymerBlocksGenerator()
+algorithm = PolymerBlocks(configuration=configuration)
+items = list(algorithm.sample(10))
+print(items)
+```
+
+## Running algorithms via the registry
+
+Here we show how the toolkit algorithms can be instantiated and run using the {py:class}`ApplicationsRegistry`.
+
+Next we see an example of {py:class}`PaccMannRL`, a `conditional_generation` algorithm:
+
+```{code-block} python
+from gt4sd.algorithms.registry import ApplicationsRegistry
+target = 'MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTT'
+algorithm = ApplicationsRegistry.get_application_instance(
+    target=target,
+    algorithm_type='conditional_generation',
+    domain='materials',
+    algorithm_name='PaccMannRL',
+    algorithm_application='PaccMannRLProteinBasedGenerator',
+    generated_length=5,
+)
+items = list(algorithm.sample(10))
+print(items)
+```
+
+Similarly we can use the registry to run {py:class}`PolymerBlocks`:
+
+```{code-block} python
+from gt4sd.algorithms.registry import ApplicationsRegistry
+algorithm = ApplicationsRegistry.get_application_instance(
+    algorithm_type='generation',
+    domain='materials',
+    algorithm_name='PolymerBlocks',
+    algorithm_application='PolymerBlocksGenerator',
+    generated_length=10,
+)
+items = list(algorithm.sample(10))
+print(items)
+```
\ No newline at end of file
diff --git a/extras_requirements.txt b/extras_requirements.txt
new file mode 100644
index 000000000..2aa26a7b8
--- /dev/null
+++ b/extras_requirements.txt
@@ -0,0 +1,3 @@
+# extras requirements from custom pypi repositories
+cogmol_inference==0.5.0
+AMD-Analytics==3.2.13
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 000000000..12b60b81b
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,13 @@
+[tool.black]
+line-length = 88
+skip-string-normalization = false
+target-version = ['py37']
+
+[tool.isort]
+multi_line_output = 3
+include_trailing_comma = true
+force_grid_wrap = 0
+use_parentheses = true
+ensure_newline_before_comments = true
+line_length = 88
+force_to_top = ["rdkit", "scikit-learn"]
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 000000000..e33b55a81
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,29 @@
+# pypi requirements
+minio==7.0.1
+numpy>=1.16.5
+torch>=1.0
+torchmetrics<0.7
+typing_extensions>=3.7.4.3
+pydantic>=1.7.3
+tape-proteins>=0.4
+scikit-learn<=0.24.2
+scikit-optimize>=0.8.1
+pytorch_lightning<=1.3.1
+regex>=2.5.91
+transformers>=4.2.1
+sentencepiece>=0.1.95
+datasets>=1.11.0
+keybert==0.2.0
+reinvent-chemistry==0.0.38
+tensorboard!=2.5.0,>=2.2.0
+rdkit-pypi>=2020.9.5.2
+# vcs requirements
+reinvent_models @ git+https://github.com/GT4SD/reinvent_models@v0.0.1
+guacamol_baselines @ git+https://github.com/GT4SD/guacamol_baselines.git@v0.0.1
+pytoda @
git+https://github.com/PaccMann/paccmann_datasets@0.1.1 +paccmann_predictor @ git+https://github.com/PaccMann/paccmann_predictor@sarscov2 +paccmann_chemistry @ git+https://github.com/PaccMann/paccmann_chemistry@0.0.4 +paccmann_generator @ git+https://github.com/PaccMann/paccmann_generator@0.0.2 +paccmann_omics @ git+https://github.com/PaccMann/paccmann_omics@0.0.1.1 +paccmann_gp @ git+https://github.com/PaccMann/paccmann_gp@02b3463 +terminator @ git+https://github.com/IBM/regression-transformer@gt4sd diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 000000000..533da8852 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,156 @@ +[metadata] +name = gt4sd +version = attr: gt4sd.__version__ +description = Generative Toolkit for Scientific Discovery (GT4SD). +author= GT4SD team +long_description = file: README.md +keywords = GT4SD Generative Models Inference Training +python_requires = >= 3.7.* +classifiers = + Operating System :: OS Independent + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.7 + +[options] +install_requires = + minio + numpy + torch + torchmetrics + typing_extensions + pydantic + scikit-learn + scikit-optimize + pytorch_lightning + regex + transformers + sentencepiece + datasets + keybert + reinvent-chemistry + tensorboard + rdkit-pypi +setup_requires = + setuptools +package_dir = + = src +packages=find_namespace: +include_package_data = True + +[options.entry_points] +console_scripts= + gt4sd-trainer = gt4sd.cli.trainer:main + gt4sd-pl-to-hf = gt4sd.cli.pl_to_hf_converter:main + gt4sd-hf-to-st = gt4sd.cli.hf_to_st_converter:main + +[options.packages.find] +where = src + +[options.package_data] +gt4sd = + py.typed + training_pipelines/*json + training_pipelines/tests/*json + training_pipelines/tests/*smi + +[options.extras_require] +extras = + cogmol-inference + AMD-Analytics + +[flake8] +max-line-length = 80 +select = C,E,F,W,B,B950 +ignore = E203, E501, W503 + +[mypy] +check_untyped_defs = True +plugins = pydantic.mypy + +[mypy-pytest.*] +ignore_missing_imports = True + +[mypy-rdkit.*] +ignore_missing_imports = True + +[mypy-setuptools.*] +ignore_missing_imports = True + +[mypy-minio.*] +ignore_missing_imports = True + +[mypy-numpy.*] +ignore_missing_imports = True + +[mypy-pandas.*] +ignore_missing_imports = True + +[mypy-paccmann_chemistry.*] +ignore_missing_imports = True + +[mypy-paccmann_omics.*] +ignore_missing_imports = True + +[mypy-pytoda.*] +ignore_missing_imports = True + +[mypy-tape.*] +ignore_missing_imports = True + +[mypy-skopt.*] +ignore_missing_imports = True + +[mypy-regex.*] +ignore_missing_imports = True + +[mypy-transformers.*] +ignore_missing_imports = True + +# to avoid mypy from crashing (https://github.com/python/mypy/issues/11045) +[mypy-transformers.trainer] +check_untyped_defs = False + +[mypy-torch.*] +ignore_missing_imports = True + +[mypy-keybert.*] +ignore_missing_imports = True + +[mypy-sentence_transformers.*] +ignore_missing_imports = True + +[mypy-cog.*] +ignore_missing_imports = True + +[mypy-pag.*] +ignore_missing_imports = True + +[mypy-reinvent_chemistry.*] +ignore_missing_imports = True + +[mypy-reinvent_models.*] +ignore_missing_imports = True + +[mypy-guacamol_baselines.*] +ignore_missing_imports = True + +[mypy-AMD_Analytics.*] +ignore_missing_imports = True + +[mypy-paccmann_predictor.*] +ignore_missing_imports = True + +[mypy-paccmann_gp.*] +ignore_missing_imports = True + +[mypy-selfies.*] +ignore_missing_imports = True + +[mypy-sklearn.*] +ignore_missing_imports = True + +[mypy-joblib.*] 
+ignore_missing_imports = True + +[mypy-terminator.*] +ignore_missing_imports = True diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..bac24a43d --- /dev/null +++ b/setup.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python + +import setuptools + +if __name__ == "__main__": + setuptools.setup() diff --git a/src/gt4sd/__init__.py b/src/gt4sd/__init__.py new file mode 100644 index 000000000..dbd77982d --- /dev/null +++ b/src/gt4sd/__init__.py @@ -0,0 +1,4 @@ +"""Module initialization.""" + +__version__ = "0.22.0" +__name__ = "gt4sd" diff --git a/src/gt4sd/algorithms/__init__.py b/src/gt4sd/algorithms/__init__.py new file mode 100644 index 000000000..501fb685b --- /dev/null +++ b/src/gt4sd/algorithms/__init__.py @@ -0,0 +1,43 @@ +"""Module initialization for gt4sd.""" +from ..extras import EXTRAS_ENABLED +from .conditional_generation.guacamol.core import ( # noqa: F401 + AaeGenerator, + GraphGAGenerator, + GraphMCTSGenerator, + OrganGenerator, + SMILESGAGenerator, + SMILESLSTMHCGenerator, + SMILESLSTMPPOGenerator, + VaeGenerator, +) +from .conditional_generation.key_bert.core import KeyBERTGenerator # noqa: F401 + +# here we import the applications to register them +from .conditional_generation.paccmann_rl.core import ( # noqa: F401 + PaccMannRLOmicBasedGenerator, + PaccMannRLProteinBasedGenerator, +) +from .conditional_generation.reinvent.core import ReinventGenerator # noqa: F401 +from .conditional_generation.template.core import TemplateGenerator # noqa: F401 +from .controlled_sampling.advanced_manufacturing.core import ( # noqa: F401 + CatalystGenerator, +) +from .controlled_sampling.paccmann_gp.core import PaccMannGPGenerator # noqa: F401 +from .generation.hugging_face.core import ( # noqa: F401 + HuggingFaceCTRLGenerator, + HuggingFaceGPT2Generator, + HuggingFaceOpenAIGPTGenerator, + HuggingFaceTransfoXLGenerator, + HuggingFaceXLMGenerator, + HuggingFaceXLNetGenerator, +) +from .generation.polymer_blocks.core import PolymerBlocksGenerator # noqa: F401 +from .prediction.topics_zero_shot.core import TopicsPredictor # noqa: F401 + +# extras requirements +if EXTRAS_ENABLED: + from .controlled_sampling.class_controlled_sampling.core import ( # noqa: F401 + PAG, + CogMol, + ) + from .generation.molgx.core import MolGXQM9Generator # noqa: F401 diff --git a/src/gt4sd/algorithms/conditional_generation/__init__.py b/src/gt4sd/algorithms/conditional_generation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/gt4sd/algorithms/conditional_generation/guacamol/__init__.py b/src/gt4sd/algorithms/conditional_generation/guacamol/__init__.py new file mode 100644 index 000000000..a34e87b72 --- /dev/null +++ b/src/gt4sd/algorithms/conditional_generation/guacamol/__init__.py @@ -0,0 +1,27 @@ +"""GuacaMol initialization.""" + +from .core import ( + AaeGenerator, + GraphGAGenerator, + GraphMCTSGenerator, + GuacaMolGenerator, + MosesGenerator, + OrganGenerator, + SMILESGAGenerator, + SMILESLSTMHCGenerator, + SMILESLSTMPPOGenerator, + VaeGenerator, +) + +__all__ = [ + "GuacaMolGenerator", + "SMILESGAGenerator", + "GraphGAGenerator", + "GraphMCTSGenerator", + "SMILESLSTMHCGenerator", + "SMILESLSTMPPOGenerator", + "MosesGenerator", + "VaeGenerator", + "AaeGenerator", + "OrganGenerator", +] diff --git a/src/gt4sd/algorithms/conditional_generation/guacamol/core.py b/src/gt4sd/algorithms/conditional_generation/guacamol/core.py new file mode 100644 index 000000000..ae8ac7cec --- /dev/null +++ b/src/gt4sd/algorithms/conditional_generation/guacamol/core.py @@ -0,0 +1,673 @@ +import 
logging +from dataclasses import field +from typing import Any, Callable, ClassVar, Dict, Iterable, Optional, TypeVar + +from ...core import AlgorithmConfiguration, GeneratorAlgorithm +from ...registry import ApplicationsRegistry +from .implementation import ( + AaeIterator, + Generator, + GraphGAIterator, + GraphMCTSIterator, + OrganIterator, + SMILESGAIterator, + SMILESLSTMHCIterator, + SMILESLSTMPPOIterator, + VaeIterator, +) + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + +T = TypeVar("T", bound=Any) +S = TypeVar("S", bound=Any) +Targeted = Callable[[T], Iterable[Any]] + + +class GuacaMolGenerator(GeneratorAlgorithm[S, T]): + """GuacaMol generation algorithm.""" + + def __init__( + self, + configuration: AlgorithmConfiguration[S, T], + target: Optional[T], + ): + """ + Instantiate GuacaMolGenerator ready to generate samples. + + Args: + configuration: domain and application + specification defining parameters, types and validations. + target: a target for which to generate items. + + Example: + An example for generating molecules given a scoring function and a score:: + + config = SMILESGAGenerator() + target = {"scoring_function_name": {"target": 0.0}} + algorithm = GuacaMolGenerator(configuration=config, target=target) + items = list(algorithm.sample(1)) + print(items) + """ + + configuration = self.validate_configuration(configuration) + # TODO there might also be a validation/check on the target input + + super().__init__( + configuration=configuration, # type:ignore + target=target, # type:ignore + ) + + def get_generator( + self, + configuration: AlgorithmConfiguration[S, T], + target: Optional[T], + ) -> Targeted[T]: + """Get the function to perform the prediction via Guacamol's generator. + + Args: + configuration: helps to set up specific application of Guacamol. + + Returns: + callable with target generating samples. 
+ """ + logger.info("ensure artifacts for the application are present.") + self.local_artifacts = configuration.ensure_artifacts() + implementation: Generator = configuration.get_conditional_generator( # type: ignore + self.local_artifacts + ) + return implementation.generate_batch # type: ignore + + +@ApplicationsRegistry.register_algorithm_application(GuacaMolGenerator) +class SMILESGAGenerator(AlgorithmConfiguration[str, str]): + """Configuration to generate optimizied molecules using SMILES Genetic algorithm""" + + algorithm_name: ClassVar[str] = GuacaMolGenerator.__name__ + algorithm_type: ClassVar[str] = "conditional_generation" + domain: ClassVar[str] = "materials" + algorithm_version: str = "v0" + + batch_size: int = field( + default=32, + metadata=dict(description="Batch size used for the generative model sampling."), + ) + population_size: int = field( + default=100, + metadata=dict( + description="it is used with n_mutations for the initial generation of smiles within the population" + ), + ) + n_mutations: int = field( + default=200, + metadata=dict( + description="it is used with population size for the initial generation of smiles within the population" + ), + ) + n_jobs: int = field( + default=-1, + metadata=dict(description="number of concurrently running jobs"), + ) + gene_size: int = field( + default=2, + metadata=dict( + description="size of the gene which is used in creation of genes" + ), + ) + random_start: bool = field( + default=False, + metadata=dict( + description="set to True to randomly choose list of SMILES for generating optimizied molecules" + ), + ) + generations: int = field( + default=2, + metadata=dict(description="number of evolutionary generations"), + ) + patience: int = field( + default=4, + metadata=dict( + description="it is used for early stopping if population scores remains the same after generating molecules" + ), + ) + + def get_target_description(self) -> Dict[str, str]: + """Get description of the target for generation. + + Returns: + target description. + """ + return { + "title": "Scoring functions with parameters", + "description": "Scoring functions will be used to generate a score for SMILES.", + "type": "object", + } + + def get_conditional_generator(self, resources_path: str) -> SMILESGAIterator: + """Instantiate the actual generator implementation. + + Args: + resources_path: local path to model files. + + Returns: + instance with :meth:`generate_batch` method for targeted generation. 
+ """ + return SMILESGAIterator( + resource_path=resources_path, + population_size=self.population_size, + n_mutations=self.n_mutations, + n_jobs=self.n_jobs, + random_start=self.random_start, + gene_size=self.gene_size, + generations=self.generations, + patience=self.patience, + batch_size=self.batch_size, + ) + + +@ApplicationsRegistry.register_algorithm_application(GuacaMolGenerator) +class GraphGAGenerator(AlgorithmConfiguration[str, str]): + """Configuration to generate optimizied molecules using Graph-Based Genetic algorithm""" + + algorithm_name: ClassVar[str] = GuacaMolGenerator.__name__ + algorithm_type: ClassVar[str] = "conditional_generation" + domain: ClassVar[str] = "materials" + algorithm_version: str = "v0" + + batch_size: int = field( + default=1, + metadata=dict(description="Batch size used for the generative model sampling."), + ) + population_size: int = field( + default=100, + metadata=dict( + description="it is used with n_mutations for the initial generation of smiles within the population" + ), + ) + mutation_rate: float = field( + default=0.01, + metadata=dict( + description="frequency of the new mutations in a single gene or organism over time" + ), + ) + offspring_size: int = field( + default=200, + metadata=dict(description="number of molecules to select for new population"), + ) + n_jobs: int = field( + default=-1, + metadata=dict(description="number of concurrently running jobs"), + ) + random_start: bool = field( + default=False, + metadata=dict( + description="set to True to randomly choose list of SMILES for generating optimizied molecules" + ), + ) + generations: int = field( + default=2, + metadata=dict(description="number of evolutionary generations"), + ) + patience: int = field( + default=4, + metadata=dict( + description="it is used for early stopping if population scores remains the same after generating molecules" + ), + ) + + def get_target_description(self) -> Dict[str, str]: + """Get description of the target for generation. + + Returns: + target description. + """ + return { + "title": "Scoring functions with parameters", + "description": "Scoring functions will be used to generate a score for SMILES.", + "type": "object", + } + + def get_conditional_generator(self, resources_path: str) -> GraphGAIterator: + """Instantiate the actual generator implementation. + + Args: + resources_path: local path to model files. + + Returns: + instance with :meth:`generate_batch` method for targeted generation. 
+ """ + return GraphGAIterator( + resource_path=resources_path, + batch_size=self.batch_size, + offspring_size=self.offspring_size, + population_size=self.population_size, + mutation_rate=self.mutation_rate, + n_jobs=self.n_jobs, + random_start=self.random_start, + generations=self.generations, + patience=self.patience, + ) + + +@ApplicationsRegistry.register_algorithm_application(GuacaMolGenerator) +class GraphMCTSGenerator(AlgorithmConfiguration[str, str]): + """Configuration to generate optimizied molecules using Graph-based Genetic Algorithm and Generative Model/Monte Carlo Tree Search for the Exploration of Chemical Space""" + + algorithm_name: ClassVar[str] = GuacaMolGenerator.__name__ + algorithm_type: ClassVar[str] = "conditional_generation" + domain: ClassVar[str] = "materials" + algorithm_version: str = "v0" + + batch_size: int = field( + default=1, + metadata=dict(description="Batch size used for the generative model sampling."), + ) + init_smiles: str = field( + default="", + metadata=dict(description="initial SMILES used for generation of states."), + ) + population_size: int = field( + default=100, + metadata=dict( + description="it is used with n_mutations for the initial generation of smiles within the population" + ), + ) + n_jobs: int = field( + default=-1, + metadata=dict(description="number of concurrently running jobs"), + ) + generations: int = field( + default=1000, + metadata=dict(description="number of evolutionary generations"), + ) + patience: int = field( + default=4, + metadata=dict( + description="it is used for early stopping if population scores remains the same after generating molecules" + ), + ) + num_sims: float = field( + default=40, + metadata=dict(description="number of times to traverse the tree"), + ) + max_children: int = field( + default=25, + metadata=dict(description="maximum number of childerns a node could have"), + ) + max_atoms: int = field( + default=60, + metadata=dict( + description="maximum number of atoms to explore to terminal the node state" + ), + ) + + def get_target_description(self) -> Dict[str, str]: + """Get description of the target for generation. + + Returns: + target description. + """ + return { + "title": "Scoring functions with parameters", + "description": "Scoring functions will be used to generate a score for SMILES.", + "type": "object", + } + + def get_conditional_generator(self, resources_path: str) -> GraphMCTSIterator: + """Instantiate the actual generator implementation. + + Args: + resources_path: local path to model files. + + Returns: + instance with :meth:`generate_batch` method for targeted generation. 
+ """ + return GraphMCTSIterator( + init_smiles=self.init_smiles, + batch_size=self.batch_size, + population_size=self.population_size, + max_children=self.max_children, + num_sims=self.num_sims, + generations=self.generations, + n_jobs=self.n_jobs, + max_atoms=self.max_atoms, + patience=self.patience, + ) + + +@ApplicationsRegistry.register_algorithm_application(GuacaMolGenerator) +class SMILESLSTMHCGenerator(AlgorithmConfiguration[str, str]): + """Configuration to generate optimizied molecules using recurrent neural networks with hill climbing algorithm""" + + algorithm_name: ClassVar[str] = GuacaMolGenerator.__name__ + algorithm_type: ClassVar[str] = "conditional_generation" + domain: ClassVar[str] = "materials" + algorithm_version: str = "v0" + + batch_size: int = field( + default=1, + metadata=dict(description="Batch size used for the generative model sampling."), + ) + n_jobs: int = field( + default=-1, + metadata=dict(description="number of concurrently running jobs"), + ) + n_epochs: int = field( + default=20, + metadata=dict(description="number of epochs to sample"), + ) + mols_to_sample: int = field( + default=1024, + metadata=dict(description="molecules sampled at each step"), + ) + keep_top: int = field( + default=512, + metadata=dict(description="maximum length of a SMILES string"), + ) + optimize_n_epochs: int = field( + default=2, + metadata=dict(description="number of epochs for the optimization"), + ) + max_len: int = field( + default=100, + metadata=dict(description="maximum length of a SMILES string"), + ) + optimize_batch_size: int = field( + default=256, + metadata=dict(description="batch size for the optimization"), + ) + benchmark_num_samples: int = field( + default=4096, + metadata=dict( + description="number of molecules to generate from final model for the benchmark" + ), + ) + random_start: bool = field( + default=False, + metadata=dict( + description="set to True to randomly choose list of SMILES for generating optimizied molecules" + ), + ) + + def get_target_description(self) -> Dict[str, str]: + """Get description of the target for generation. + + Returns: + target description. + """ + return { + "title": "Scoring functions with parameters", + "description": "Scoring functions will be used to generate a score for SMILES.", + "type": "object", + } + + def get_conditional_generator(self, resources_path: str) -> SMILESLSTMHCIterator: + """Instantiate the actual generator implementation. + + Args: + resources_path: local path to model files. + + Returns: + instance with :meth:`generate_batch` method for targeted generation. 
+ """ + return SMILESLSTMHCIterator( + resource_path=resources_path, + batch_size=self.batch_size, + n_epochs=self.n_epochs, + mols_to_sample=self.mols_to_sample, + keep_top=self.keep_top, + optimize_n_epochs=self.optimize_n_epochs, + max_len=self.max_len, + optimize_batch_size=self.optimize_batch_size, + benchmark_num_samples=self.benchmark_num_samples, + random_start=self.random_start, + n_jobs=self.n_jobs, + ) + + +@ApplicationsRegistry.register_algorithm_application(GuacaMolGenerator) +class SMILESLSTMPPOGenerator(AlgorithmConfiguration[str, str]): + """Configuration to generate optimizied molecules using recurrent neural networks with hill climbing algorithm""" + + algorithm_name: ClassVar[str] = GuacaMolGenerator.__name__ + algorithm_type: ClassVar[str] = "conditional_generation" + domain: ClassVar[str] = "materials" + algorithm_version: str = "v0" + + batch_size: int = field( + default=1, + metadata=dict(description="Batch size used for the generative model sampling."), + ) + num_epochs: int = field( + default=20, + metadata=dict(description="number of epochs to sample"), + ) + episode_size: int = field( + default=8192, + metadata=dict( + description="number of molecules sampled by the policy at the start of a series of ppo updates" + ), + ) + optimize_batch_size: int = field( + default=1024, + metadata=dict(description="batch size for the optimization"), + ) + entropy_weight: int = field( + default=1, + metadata=dict(description="used for calculating entropy loss"), + ) + kl_div_weight: int = field( + default=10, + metadata=dict( + description="used for calculating Kullback-Leibler divergence loss" + ), + ) + clip_param: float = field( + default=0.2, + metadata=dict( + description="used for determining how far the new policy is from the old one" + ), + ) + + def get_target_description(self) -> Dict[str, str]: + """Get description of the target for generation. + + Returns: + target description. + """ + return { + "title": "Scoring functions with parameters", + "description": "Scoring functions will be used to generate a score for SMILES.", + "type": "object", + } + + def get_conditional_generator(self, resources_path: str) -> SMILESLSTMPPOIterator: + """Instantiate the actual generator implementation. + + Args: + resources_path: local path to model files. + + Returns: + instance with :meth:`generate_batch` method for targeted generation. + """ + return SMILESLSTMPPOIterator( + resource_path=resources_path, + batch_size=self.batch_size, + num_epochs=self.num_epochs, + episode_size=self.episode_size, + optimize_batch_size=self.optimize_batch_size, + entropy_weight=self.entropy_weight, + kl_div_weight=self.kl_div_weight, + clip_param=self.clip_param, + ) + + +class MosesGenerator(GeneratorAlgorithm[S, T]): + """Moses generation algorithm.""" + + def __init__( + self, + configuration: AlgorithmConfiguration[S, T], + target: Optional[T], + ): + """ + Instantiate GuacaMolGenerator ready to generate samples. + + Args: + configuration: domain and application + specification defining parameters, types and validations. + target: a target for which to generate items. 
+ + Example: + An example for generating molecules given a scoring function and a score: + + config = AaeGenerator() + algorithm = MosesGenerator(configuration=config, target="") + items = list(algorithm.sample(1)) + print(items) + """ + + configuration = self.validate_configuration(configuration) + # TODO there might also be a validation/check on the target input + + super().__init__( + configuration=configuration, # type:ignore + target=target, # type:ignore + ) + + def get_generator( + self, + configuration: AlgorithmConfiguration[S, T], + target: Optional[T], + ) -> Targeted[T]: + """Get the function to perform the prediction via Guacamol's generator. + + Args: + configuration: helps to set up specific application of Guacamol. + + Returns: + callable with target generating samples. + """ + logger.info("ensure artifacts for the application are present.") + self.local_artifacts = configuration.ensure_artifacts() + implementation: Generator = configuration.get_conditional_generator( # type: ignore + self.local_artifacts + ) + return implementation.generate_batch # type: ignore + + +@ApplicationsRegistry.register_algorithm_application(MosesGenerator) +class AaeGenerator(AlgorithmConfiguration[str, str]): + """Configuration to generate molecules using Variational autoencoder""" + + algorithm_name: ClassVar[str] = MosesGenerator.__name__ + algorithm_type: ClassVar[str] = "conditional_generation" + domain: ClassVar[str] = "materials" + algorithm_version: str = "v0" + + n_samples: int = field( + default=20, + metadata=dict(description="Number of SMILES to generate"), + ) + n_batch: int = field( + default=1024, + metadata=dict(description="Batch size for the optimization"), + ) + max_len: int = field( + default=100, + metadata=dict(description="Maximum length of the generated SMILES"), + ) + + def get_conditional_generator(self, resources_path: str) -> AaeIterator: + """Instantiate the actual generator implementation. + + Args: + resources_path: local path to model files. + + Returns: + instance with :meth:`generate_batch` method for targeted generation. + """ + return AaeIterator( + resource_path=resources_path, + n_samples=self.n_samples, + n_batch=self.n_batch, + max_len=self.max_len, + ) + + +@ApplicationsRegistry.register_algorithm_application(MosesGenerator) +class VaeGenerator(AlgorithmConfiguration[str, str]): + """Configuration to generate molecules using Adversarial autoencoder""" + + algorithm_name: ClassVar[str] = MosesGenerator.__name__ + algorithm_type: ClassVar[str] = "conditional_generation" + domain: ClassVar[str] = "materials" + algorithm_version: str = "v0" + + n_samples: int = field( + default=20, + metadata=dict(description="Number of SMILES to generate"), + ) + n_batch: int = field( + default=1024, + metadata=dict(description="Batch size for the optimization"), + ) + max_len: int = field( + default=100, + metadata=dict(description="Maximum length of the generated SMILES"), + ) + + def get_conditional_generator(self, resources_path: str) -> VaeIterator: + """Instantiate the actual generator implementation. + + Args: + resources_path: local path to model files. + + Returns: + instance with :meth:`generate_batch` method for targeted generation. 
+ """ + return VaeIterator( + resource_path=resources_path, + n_samples=self.n_samples, + n_batch=self.n_batch, + max_len=self.max_len, + ) + + +@ApplicationsRegistry.register_algorithm_application(MosesGenerator) +class OrganGenerator(AlgorithmConfiguration[str, str]): + """Configuration to generate molecules using Objective-Reinforced Generative Adversarial Network""" + + algorithm_name: ClassVar[str] = MosesGenerator.__name__ + algorithm_type: ClassVar[str] = "conditional_generation" + domain: ClassVar[str] = "materials" + algorithm_version: str = "v0" + + n_samples: int = field( + default=20, + metadata=dict(description="Number of SMILES to generate"), + ) + n_batch: int = field( + default=1024, + metadata=dict(description="Batch size for the optimization"), + ) + max_len: int = field( + default=100, + metadata=dict(description="Maximum length of the generated SMILES"), + ) + + def get_conditional_generator(self, resources_path: str) -> OrganIterator: + """Instantiate the actual generator implementation. + + Args: + resources_path: local path to model files. + + Returns: + instance with :meth:`generate_batch` method for targeted generation. + """ + return OrganIterator( + resource_path=resources_path, + n_samples=self.n_samples, + n_batch=self.n_batch, + max_len=self.max_len, + ) diff --git a/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/__init__.py b/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/__init__.py new file mode 100644 index 000000000..0fef4edde --- /dev/null +++ b/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/__init__.py @@ -0,0 +1,599 @@ +"""GuacaMol algorithms implementation module.""" + +import json +import logging +import os +from typing import Any, Dict, List, Tuple, Type, Union + +from guacamol_baselines.graph_ga.goal_directed_generation import GB_GA_Generator +from guacamol_baselines.graph_mcts.goal_directed_generation import GB_MCTS_Generator +from guacamol_baselines.moses_baselines.aae_distribution_learning import AaeGenerator +from guacamol_baselines.moses_baselines.organ_distribution_learning import ( + OrganGenerator, +) +from guacamol_baselines.moses_baselines.vae_distribution_learning import VaeGenerator +from guacamol_baselines.smiles_ga.goal_directed_generation import ChemGEGenerator +from guacamol_baselines.smiles_lstm_hc.goal_directed_generation import ( + SmilesRnnDirectedGenerator, +) +from guacamol_baselines.smiles_lstm_ppo.goal_directed_generation import ( + PPODirectedGenerator, +) + +from .....domains.materials.scorer import SCORING_FUNCTIONS, CombinedScorer +from .graph_ga import GraphGA +from .graph_mcts import GraphMCTS +from .moses_aae import AAE +from .moses_organ import Organ +from .moses_vae import VAE +from .smiles_ga import SMILESGA +from .smiles_lstm_hc import SMILESLSTMHC +from .smiles_lstm_ppo import SMILESLSTMPPO + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +def get_target_parameters( + target: Union[str, Dict[str, Any]] +) -> Tuple[List[Type[Any]], List[float]]: + """Generates a tuple of scorers and weight list + + Args: + target: scoring functions and parameters related to it + + Return: + A tuple containing scoring functions and weight list + """ + score_list = [] + weights = [] + target_dictionary: Dict[str, Any] = {} + if isinstance(target, str): + target_dictionary = json.loads(target) + elif isinstance(target, dict): + target_dictionary = target + else: + raise ValueError( + f"{target} of type {type(target)} is not supported: 
provide 'str' or 'Dict[str, Any]'"
+        )
+    for scoring_function_name, parameters in target_dictionary.items():
+        weight = 1.0
+        if "weight" in parameters:
+            weight = parameters.pop("weight")
+        score_list.append(SCORING_FUNCTIONS[scoring_function_name](**parameters))
+        weights.append(weight)
+    return (score_list, weights)
+
+
+class Generator:
+    """Abstract interface for a conditional generator."""
+
+    def generate_batch(self, target) -> List[Any]:
+        """Generate a batch of molecules.
+
+        Args:
+            target: condition used for generation.
+
+        Returns:
+            the generated molecules.
+        """
+        raise NotImplementedError(
+            "Implementation not found for generation of molecules."
+        )
+
+
+class SMILESGAIterator(Generator):
+    def __init__(
+        self,
+        resource_path,
+        batch_size: int,
+        population_size: int,
+        n_mutations: int,
+        n_jobs: int,
+        random_start: bool,
+        gene_size: int,
+        generations: int,
+        patience: int,
+    ):
+        """Initialize Generator.
+
+        Args:
+            resource_path: path to load the hypothesis, candidate labels and, optionally, the smiles file.
+            batch_size: number of molecules to generate
+            population_size: used with n_mutations for the initial generation of smiles within the population
+            n_mutations: used with population size for the initial generation of smiles within the population
+            n_jobs: number of concurrently running jobs
+            random_start: set to True to randomly choose list of SMILES for generating optimized molecules
+            gene_size: size of the gene which is used in creation of genes
+            generations: number of evolutionary generations
+            patience: used for early stopping if population scores remain the same after generating molecules
+        """
+        self.resource_path = resource_path
+        self.batch_size = batch_size
+        self.population_size = population_size
+        self.n_mutations = n_mutations
+        self.n_jobs = n_jobs
+        self.random_start = random_start
+        self.gene_size = gene_size
+        self.generations = generations
+        self.patience = patience
+        self.chemGenerator: ChemGEGenerator = None
+
+    def generate_batch(self, target) -> List[Any]:
+        """Generate a batch of molecules.
+
+        Args:
+            target: condition used for generation.
+
+        Returns:
+            the generated molecules.
+        """
+        score_list, weights = get_target_parameters(target)
+
+        self.scoring_function = CombinedScorer(
+            scorer_list=score_list,
+            weights=weights,
+        )
+        if self.chemGenerator is None:
+            optimiser = SMILESGA(
+                smi_file=os.path.join(self.resource_path, "guacamol_v1_all.smiles"),
+                population_size=self.population_size,
+                n_mutations=self.n_mutations,
+                gene_size=self.gene_size,
+                generations=self.generations,
+                n_jobs=self.n_jobs,
+                random_start=self.random_start,
+                patience=self.patience,
+            )
+            logger.info("Initialization of the Generator")
+            self.chemGenerator = optimiser.get_generator()
+
+        logger.info("generating molecules")
+        molecules = self.chemGenerator.generate_optimized_molecules(
+            self.scoring_function, self.batch_size
+        )
+        return molecules
+
+
+class GraphGAIterator(Generator):
+    def __init__(
+        self,
+        resource_path,
+        batch_size: int,
+        population_size: int,
+        offspring_size: int,
+        n_jobs: int,
+        mutation_rate: float,
+        random_start: bool,
+        generations: int,
+        patience: int,
+    ):
+        """Initialize Generator.
+
+        Args:
+            resource_path: path to load the hypothesis, candidate labels and, optionally, the smiles file.
+            batch_size: number of molecules to generate
+            population_size: used for the initial generation of smiles within the population
+            n_jobs: number of concurrently running jobs
+            random_start: set to True to randomly choose list of SMILES for generating optimized molecules
+            offspring_size: number of molecules to select for new population
+            mutation_rate: frequency of the new mutations in a single gene or organism over time
+            generations: number of evolutionary generations
+            patience: used for early stopping if population scores remain the same after generating molecules
+        """
+        self.resource_path = resource_path
+        self.batch_size = batch_size
+        self.population_size = population_size
+        self.n_jobs = n_jobs
+        self.random_start = random_start
+        self.offspring_size = offspring_size
+        self.mutation_rate = mutation_rate
+        self.generations = generations
+        self.patience = patience
+        self.gb_ga_generator: GB_GA_Generator = None
+
+    def generate_batch(self, target) -> List[Any]:
+        """Generate a batch of molecules.
+
+        Args:
+            target: condition used for generation.
+
+        Returns:
+            the generated molecules.
+        """
+        score_list, weights = get_target_parameters(target)
+
+        self.scoring_function = CombinedScorer(
+            scorer_list=score_list,
+            weights=weights,
+        )
+        if self.gb_ga_generator is None:
+            optimiser = GraphGA(
+                smi_file=os.path.join(self.resource_path, "guacamol_v1_all.smiles"),
+                population_size=self.population_size,
+                mutation_rate=self.mutation_rate,
+                offspring_size=self.offspring_size,
+                generations=self.generations,
+                n_jobs=self.n_jobs,
+                random_start=self.random_start,
+                patience=self.patience,
+            )
+            logger.info("Initialization of the Generator")
+            self.gb_ga_generator = optimiser.get_generator()
+
+        logger.info("generating molecules")
+        molecules = self.gb_ga_generator.generate_optimized_molecules(
+            self.scoring_function, self.batch_size
+        )
+        return molecules
+
+
+class GraphMCTSIterator(Generator):
+    def __init__(
+        self,
+        init_smiles: str,
+        batch_size: int,
+        population_size: int,
+        max_children: int,
+        n_jobs: int,
+        num_sims: float,
+        max_atoms: int,
+        generations: int,
+        patience: int,
+    ):
+        """Initialize Generator.
+
+        Args:
+            init_smiles: initial SMILES used for generation of states.
+            batch_size: number of molecules to generate
+            population_size: used for the initial generation of smiles within the population
+            max_children: maximum number of children a node could have
+            n_jobs: number of concurrently running jobs
+            num_sims: number of times to traverse the tree
+            max_atoms: maximum number of atoms to explore before terminating the node state
+            generations: number of evolutionary generations
+            patience: used for early stopping if population scores remain the same after generating molecules
+        """
+        self.init_smiles = init_smiles
+        self.batch_size = batch_size
+        self.population_size = population_size
+        self.max_children = max_children
+        self.n_jobs = n_jobs
+        self.num_sims = num_sims
+        self.max_atoms = max_atoms
+        self.generations = generations
+        self.patience = patience
+        self.graph_mcts_generator: GB_MCTS_Generator = None
+
+    def generate_batch(self, target) -> List[Any]:
+        """Generate a batch of molecules.
+
+        Args:
+            target: condition used for generation.
+
+        Returns:
+            the generated molecules.
+        """
+        score_list, weights = get_target_parameters(target)
+
+        self.scoring_function = CombinedScorer(
+            scorer_list=score_list,
+            weights=weights,
+        )
+        if self.graph_mcts_generator is None:
+            optimiser = GraphMCTS(
+                init_smiles=self.init_smiles,
+                population_size=self.population_size,
+                max_children=self.max_children,
+                num_sims=self.num_sims,
+                generations=self.generations,
+                n_jobs=self.n_jobs,
+                max_atoms=self.max_atoms,
+                patience=self.patience,
+            )
+            logger.info("Initialization of the Generator")
+            self.graph_mcts_generator = optimiser.get_generator()
+
+        logger.info("generating molecules")
+        molecules = self.graph_mcts_generator.generate_optimized_molecules(
+            self.scoring_function, self.batch_size
+        )
+        return molecules
+
+
+class SMILESLSTMHCIterator(Generator):
+    def __init__(
+        self,
+        resource_path,
+        batch_size: int,
+        n_epochs: int,
+        mols_to_sample: int,
+        n_jobs: int,
+        random_start: bool,
+        optimize_n_epochs: int,
+        benchmark_num_samples: int,
+        keep_top: int,
+        max_len: int,
+        optimize_batch_size: int,
+    ):
+        """Initialize Generator.
+
+        Args:
+            resource_path: path to load the hypothesis, candidate labels and, optionally, the smiles file.
+            batch_size: number of molecules to generate
+            n_epochs: number of epochs to sample
+            mols_to_sample: molecules sampled at each step
+            keep_top: molecules kept each step
+            optimize_n_epochs: number of epochs for the optimization
+            benchmark_num_samples: number of molecules to generate from final model for the benchmark
+            random_start: set to True to randomly choose list of SMILES for generating optimized molecules
+            n_jobs: number of concurrently running jobs
+            max_len: maximum length of a SMILES string
+            optimize_batch_size: batch size for the optimization
+        """
+        self.resource_path = resource_path
+        self.batch_size = batch_size
+        self.n_epochs = n_epochs
+        self.mols_to_sample = mols_to_sample
+        self.keep_top = keep_top
+        self.optimize_n_epochs = optimize_n_epochs
+        self.benchmark_num_samples = benchmark_num_samples
+        self.random_start = random_start
+        self.n_jobs = n_jobs
+        self.max_len = max_len
+        self.optimize_batch_size = optimize_batch_size
+        self.smiles_lstm_hc_generator: SmilesRnnDirectedGenerator = None
+
+    def generate_batch(self, target) -> List[Any]:
+        """Generate a batch of molecules.
+
+        Args:
+            target: condition used for generation.
+
+        Returns:
+            the generated molecules.
+ """ + score_list, weights = get_target_parameters(target) + + self.scoring_function = CombinedScorer( + scorer_list=score_list, + weights=weights, + ) + if self.smiles_lstm_hc_generator is None: + optimiser = SMILESLSTMHC( + model_path=os.path.join(self.resource_path, "model_final_0.473.pt"), + smi_file=os.path.join(self.resource_path, "guacamol_v1_all.smiles"), + n_epochs=self.n_epochs, + mols_to_sample=self.mols_to_sample, + keep_top=self.keep_top, + optimize_n_epochs=self.optimize_n_epochs, + max_len=self.max_len, + optimize_batch_size=self.optimize_batch_size, + benchmark_num_samples=self.benchmark_num_samples, + random_start=self.random_start, + n_jobs=self.n_jobs, + ) + logger.info("Initialization of the Generator") + self.smiles_lstm_hc_generator = optimiser.get_generator() + + logger.info("generating molecules") + molecules = self.smiles_lstm_hc_generator.generate_optimized_molecules( + self.scoring_function, self.batch_size + ) + return molecules + + +class SMILESLSTMPPOIterator(Generator): + def __init__( + self, + resource_path, + batch_size: int, + episode_size: int, + num_epochs: int, + optimize_batch_size: int, + entropy_weight: int, + kl_div_weight: int, + clip_param: float, + ): + """Initialize Generator. + + Args: + resource_path: path to load the hypothesis, candidate labels and, optionally, the smiles file. + batch_size: number of molecules to generate + episode_size: number of molecules sampled by the policy at the start of a series of ppo updates + num_epochs: number of epochs to sample + optimize_batch_size: batch size for the optimization + entropy_weight: used for calculating entropy loss + kl_div_weight: used for calculating Kullback-Leibler divergence loss + clip_param: used for determining how far the new policy is from the old one + """ + self.resource_path = resource_path + self.batch_size = batch_size + self.episode_size = episode_size + self.num_epochs = num_epochs + self.optimize_batch_size = optimize_batch_size + self.entropy_weight = entropy_weight + self.kl_div_weight = kl_div_weight + self.clip_param = clip_param + self.smiles_lstm_ppo_generator: PPODirectedGenerator = None + + def generate_batch(self, target) -> List[Any]: + """Generate a batch of molecules. + + Args: + target: condition used for generation. + + Returns: + the generated molecules. + """ + score_list, weights = get_target_parameters(target) + + self.scoring_function = CombinedScorer( + scorer_list=score_list, + weights=weights, + ) + if self.smiles_lstm_ppo_generator is None: + optimiser = SMILESLSTMPPO( + model_path=os.path.join(self.resource_path, "model_final_0.473.pt"), + num_epochs=self.num_epochs, + episode_size=self.episode_size, + optimize_batch_size=self.optimize_batch_size, + entropy_weight=self.entropy_weight, + kl_div_weight=self.kl_div_weight, + clip_param=self.clip_param, + ) + logger.info("Initialization of the Generator") + self.smiles_lstm_ppo_generator = optimiser.get_generator() + + logger.info("generating molecules") + molecules = self.smiles_lstm_ppo_generator.generate_optimized_molecules( + self.scoring_function, self.batch_size + ) + return molecules + + +class AaeIterator: + def __init__( + self, + resource_path: str, + n_samples: int, + n_batch: int, + max_len: int, + ): + """Initialize AAE. + + Args: + resource_path: path to load the hypothesis, candidate labels and, optionally, the smiles file. 
+ n_samples: Number of samples to sample + n_batch: Size of the batch + max_len: Max length of SMILES + """ + self.resource_path = resource_path + self.model_path = os.path.join(self.resource_path, "model.pt") + self.config_path = os.path.join(self.resource_path, "config.pt") + self.vocab_path = os.path.join(self.resource_path, "vocab.pt") + self.n_samples = n_samples + self.n_batch = n_batch + self.max_len = max_len + self.aae_generator: AaeGenerator = None + + def generate_batch(self, target=None) -> List[Any]: + """Generate a batch of molecules. + + Args: + target: condition used for generation. + + Returns: + the generated molecules. + """ + if self.aae_generator is None: + optimiser = AAE( + model_path=self.model_path, + model_config_path=self.config_path, + vocab_path=self.vocab_path, + n_samples=self.n_samples, + n_batch=self.n_batch, + max_len=self.max_len, + ) + logger.info("Initialization of the Generator") + self.aae_generator = optimiser.get_generator() + molecules = self.aae_generator.generate(self.n_samples) + return molecules + + +class VaeIterator: + def __init__( + self, + resource_path: str, + n_samples: int, + n_batch: int, + max_len: int, + ): + """Initialize VaeIterator. + + Args: + resource_path: path to load the hypothesis, candidate labels and, optionally, the smiles file. + n_samples: Number of samples to sample + n_batch: Size of the batch + max_len: Max length of SMILES + """ + self.resource_path = resource_path + self.model_path = os.path.join(self.resource_path, "model.pt") + self.config_path = os.path.join(self.resource_path, "config.pt") + self.vocab_path = os.path.join(self.resource_path, "vocab.pt") + self.n_samples = n_samples + self.n_batch = n_batch + self.max_len = max_len + self.vae_generator: VaeGenerator = None + + def generate_batch(self, target=None) -> List[Any]: + """Generate a batch of molecules. + + Args: + target: condition used for generation. + + Returns: + the generated molecules. + """ + if self.vae_generator is None: + optimiser = VAE( + model_path=self.model_path, + model_config_path=self.config_path, + vocab_path=self.vocab_path, + n_samples=self.n_samples, + n_batch=self.n_batch, + max_len=self.max_len, + ) + logger.info("Initialization of the Generator") + self.vae_generator = optimiser.get_generator() + molecules = self.vae_generator.generate(self.n_samples) + return molecules + + +class OrganIterator: + def __init__( + self, + resource_path: str, + n_samples: int, + n_batch: int, + max_len: int, + ): + """Initialize OrganIterator. + + Args: + resource_path: path to load the hypothesis, candidate labels and, optionally, the smiles file. + n_samples: Number of samples to sample + n_batch: Size of the batch + max_len: Max length of SMILES + """ + self.resource_path = resource_path + self.model_path = os.path.join(self.resource_path, "model.pt") + self.config_path = os.path.join(self.resource_path, "config.pt") + self.vocab_path = os.path.join(self.resource_path, "vocab.pt") + self.n_samples = n_samples + self.n_batch = n_batch + self.max_len = max_len + self.organ_generator: OrganGenerator = None + + def generate_batch(self, target=None) -> List[Any]: + """Generate a batch of molecules. + + Args: + target: condition used for generation. + + Returns: + the generated molecules. 
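+
+ Example:
+ A usage sketch (illustrative; it assumes ``model.pt``, ``config.pt``
+ and ``vocab.pt`` are present under ``resource_path``, as wired in
+ ``__init__`` above)::
+
+ iterator = OrganIterator(
+ resource_path="/path/to/artifacts", # hypothetical location
+ n_samples=32, n_batch=16, max_len=100,
+ )
+ molecules = iterator.generate_batch() # target is ignored here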
+ """ + if self.organ_generator is None: + optimiser = Organ( + model_path=self.model_path, + model_config_path=self.config_path, + vocab_path=self.vocab_path, + n_samples=self.n_samples, + n_batch=self.n_batch, + max_len=self.max_len, + ) + logger.info("Initialization of the Generator") + self.organ_generator = optimiser.get_generator() + molecules = self.organ_generator.generate(self.n_samples) + return molecules diff --git a/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/graph_ga.py b/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/graph_ga.py new file mode 100644 index 000000000..1b79a0e5b --- /dev/null +++ b/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/graph_ga.py @@ -0,0 +1,56 @@ +"""GRAPH GA implementation.""" + +from guacamol_baselines.graph_ga.goal_directed_generation import GB_GA_Generator + + +class GraphGA: + def __init__( + self, + smi_file, + mutation_rate: float, + population_size: int, + offspring_size: int, + n_jobs: int, + random_start: bool, + generations: int, + patience: int, + ): + """Initialize SMILESGA. + + Args: + smi_file: path where to load hypothesis, candidate labels and, optionally, the smiles file. + population_size: used with n_mutations for the initial generation of smiles within the population + n_jobs: number of concurrently running jobs + random_start: set to True to randomly choose list of SMILES for generating optimizied molecules + generations: number of evolutionary generations + patience: used for early stopping if population scores remains the same after generating molecules + mutation_rate: frequency of the new mutations in a single gene or organism over time + offspring_size: number of molecules to select for new population + """ + self.smi_file = smi_file + self.mutation_rate = mutation_rate + self.population_size = population_size + self.offspring_size = offspring_size + self.n_jobs = n_jobs + self.random_start = random_start + self.generations = generations + self.patience = patience + + def get_generator(self) -> GB_GA_Generator: + """ + used for creating an instance of the GB_GA_Generator + + Returns: + An instance of GB_GA_Generator + """ + optimiser = GB_GA_Generator( + smi_file=self.smi_file, + population_size=self.population_size, + offspring_size=self.offspring_size, + mutation_rate=self.mutation_rate, + generations=self.generations, + n_jobs=self.n_jobs, + random_start=self.random_start, + patience=self.patience, + ) + return optimiser diff --git a/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/graph_mcts.py b/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/graph_mcts.py new file mode 100644 index 000000000..2ff14eff0 --- /dev/null +++ b/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/graph_mcts.py @@ -0,0 +1,65 @@ +"""Graph MCTS implementation.""" + +import os + +from guacamol_baselines.graph_mcts import goal_directed_generation +from guacamol_baselines.graph_mcts.goal_directed_generation import GB_MCTS_Generator + + +class GraphMCTS: + def __init__( + self, + init_smiles: str, + population_size: int, + n_jobs: int, + generations: int, + patience: int, + num_sims: float, + max_children: int, + max_atoms: int, + pickle_directory: str = os.path.dirname( + os.path.realpath(goal_directed_generation.__file__) + ), + ): + """Initialize SMILESGA. + + Args: + init_smiles: path where to load hypothesis, candidate labels and, optionally, the smiles file. 
+ population_size: number of molecules in the population at each generation
+ n_jobs: number of concurrently running jobs
+ generations: number of evolutionary generations
+ patience: used for early stopping if population scores remain the same after generating molecules
+ num_sims: number of times to traverse the tree
+ max_children: maximum number of children a node can have
+ max_atoms: maximum number of atoms to explore before a node is considered terminal
+ pickle_directory: path from where to load pickle files
+ """
+ self.init_smiles = init_smiles
+ self.pickle_directory = pickle_directory
+ self.population_size = population_size
+ self.max_children = max_children
+ self.n_jobs = n_jobs
+ self.num_sims = num_sims
+ self.generations = generations
+ self.patience = patience
+ self.max_atoms = max_atoms
+
+ def get_generator(self) -> GB_MCTS_Generator:
+ """
+ used for creating an instance of the GB_MCTS_Generator
+
+ Returns:
+ An instance of GB_MCTS_Generator
+ """
+ optimiser = GB_MCTS_Generator(
+ pickle_directory=self.pickle_directory,
+ init_smiles=self.init_smiles,
+ population_size=self.population_size,
+ max_children=self.max_children,
+ num_sims=self.num_sims,
+ generations=self.generations,
+ n_jobs=self.n_jobs,
+ max_atoms=self.max_atoms,
+ patience=self.patience,
+ )
+ return optimiser
diff --git a/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/moses_aae.py b/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/moses_aae.py
new file mode 100644
index 000000000..3df64a652
--- /dev/null
+++ b/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/moses_aae.py
@@ -0,0 +1,46 @@
+"""Moses AAE implementation."""

+import argparse

+from guacamol_baselines.moses_baselines.aae_distribution_learning import AaeGenerator


+class AAE:
+ def __init__(
+ self,
+ model_path: str,
+ model_config_path: str,
+ vocab_path: str,
+ n_samples: int,
+ n_batch: int,
+ max_len: int,
+ ):
+ """Initialize AAE.
+ + Args: + model_path: path from where to load the model + model_config_path: path from where to load the model config + vocab_path: path from where to load the vocab + n_samples: Number of samples to sample + n_batch: Size of the batch + max_len: Max length of SMILES + """ + self.parser = argparse.ArgumentParser() + self.parser.add_argument("--model_load", default=model_path) + self.parser.add_argument("--config_load", default=model_config_path) + self.parser.add_argument("--vocab_load", default=vocab_path) + self.parser.add_argument("--n_samples", default=n_samples) + self.parser.add_argument("--n_batch", default=n_batch) + self.parser.add_argument("--max_len", default=max_len) + self.parser.add_argument("--device", default="cpu") + self.config = self.parser.parse_known_args()[0] + + def get_generator(self) -> AaeGenerator: + """ + used for creating an instance of the AaeGenerator + + Returns: + An instance of AaeGenerator + """ + optimiser = AaeGenerator(self.config) + return optimiser diff --git a/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/moses_organ.py b/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/moses_organ.py new file mode 100644 index 000000000..3bd859585 --- /dev/null +++ b/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/moses_organ.py @@ -0,0 +1,48 @@ +"""Moses Organ implementation.""" + +import argparse + +from guacamol_baselines.moses_baselines.organ_distribution_learning import ( + OrganGenerator, +) + + +class Organ: + def __init__( + self, + model_path: str, + model_config_path: str, + vocab_path: str, + n_samples: int, + n_batch: int, + max_len: int, + ): + """Initialize Organ. + + Args: + model_path: path from where to load the model + model_config_path: path from where to load the model config + vocab_path: path from where to load the vocab + n_samples: Number of samples to sample + n_batch: Size of the batch + max_len: Max length of SMILES + """ + self.parser = argparse.ArgumentParser() + self.parser.add_argument("--model_load", default=model_path) + self.parser.add_argument("--config_load", default=model_config_path) + self.parser.add_argument("--vocab_load", default=vocab_path) + self.parser.add_argument("--n_samples", default=n_samples) + self.parser.add_argument("--n_batch", default=n_batch) + self.parser.add_argument("--max_len", default=max_len) + self.parser.add_argument("--device", default="cpu") + self.config = self.parser.parse_known_args()[0] + + def get_generator(self) -> OrganGenerator: + """ + used for creating an instance of the OrganGenerator + + Returns: + An instance of OrganGenerator + """ + optimiser = OrganGenerator(self.config) + return optimiser diff --git a/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/moses_vae.py b/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/moses_vae.py new file mode 100644 index 000000000..96f36a982 --- /dev/null +++ b/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/moses_vae.py @@ -0,0 +1,51 @@ +"""Moses VAE implementation.""" + +import argparse +import logging + +from guacamol_baselines.moses_baselines.vae_distribution_learning import VaeGenerator + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class VAE: + def __init__( + self, + model_path: str, + model_config_path: str, + vocab_path: str, + n_samples: int, + n_batch: int, + max_len: int, + ): + """Initialize VAE. 
+ + Args: + model_path: path from where to load the model + model_config_path: path from where to load the model config + vocab_path: path from where to load the vocab + n_samples: Number of samples to sample + n_batch: Size of the batch + max_len: Max length of SMILES + """ + self.parser = argparse.ArgumentParser() + self.parser.add_argument("--model_load", default=model_path) + self.parser.add_argument("--config_load", default=model_config_path) + self.parser.add_argument("--vocab_load", default=vocab_path) + self.parser.add_argument("--n_samples", default=n_samples) + self.parser.add_argument("--n_batch", default=n_batch) + self.parser.add_argument("--max_len", default=max_len) + self.parser.add_argument("--device", default="cpu") + self.config = self.parser.parse_known_args()[0] + + def get_generator(self) -> VaeGenerator: + """ + used for creating an instance of the VaeGenerator + + Returns: + An instance of VaeGenerator + """ + optimiser = VaeGenerator(self.config) + logger.debug(self.config) + return optimiser diff --git a/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/smiles_ga.py b/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/smiles_ga.py new file mode 100644 index 000000000..dd7975a77 --- /dev/null +++ b/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/smiles_ga.py @@ -0,0 +1,56 @@ +"""SMILES GA implementation.""" + +from guacamol_baselines.smiles_ga.goal_directed_generation import ChemGEGenerator + + +class SMILESGA: + def __init__( + self, + smi_file, + population_size: int, + n_mutations: int, + n_jobs: int, + random_start: bool, + gene_size: int, + generations: int, + patience: int, + ): + """Initialize SMILESGA. + + Args: + smi_file: path where to load hypothesis, candidate labels and, optionally, the smiles file. 
+ population_size: used with n_mutations for the initial generation of smiles within the population
+ n_mutations: used with population size for the initial generation of smiles within the population
+ n_jobs: number of concurrently running jobs
+ random_start: set to True to randomly choose list of SMILES for generating optimized molecules
+ gene_size: size of the gene used in the creation of new candidates
+ generations: number of evolutionary generations
+ patience: used for early stopping if population scores remain the same after generating molecules
+ """
+ self.smi_file = smi_file
+ self.population_size = population_size
+ self.n_mutations = n_mutations
+ self.n_jobs = n_jobs
+ self.random_start = random_start
+ self.gene_size = gene_size
+ self.generations = generations
+ self.patience = patience
+
+ def get_generator(self) -> ChemGEGenerator:
+ """
+ used for creating an instance of ChemGEGenerator
+
+ Returns:
+ An instance of ChemGEGenerator
+ """
+ optimiser = ChemGEGenerator(
+ smi_file=self.smi_file,
+ population_size=self.population_size,
+ n_mutations=self.n_mutations,
+ generations=self.generations,
+ n_jobs=self.n_jobs,
+ random_start=self.random_start,
+ gene_size=self.gene_size,
+ patience=self.patience,
+ )
+ return optimiser
diff --git a/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/smiles_lstm_hc.py b/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/smiles_lstm_hc.py
new file mode 100644
index 000000000..f38e4478a
--- /dev/null
+++ b/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/smiles_lstm_hc.py
@@ -0,0 +1,70 @@
+"""Recurrent Neural Networks with Hill Climbing algorithm implementation."""

+from guacamol_baselines.smiles_lstm_hc.goal_directed_generation import (
+ SmilesRnnDirectedGenerator,
+)


+class SMILESLSTMHC:
+ def __init__(
+ self,
+ model_path: str,
+ smi_file,
+ max_len: int,
+ n_jobs: int,
+ keep_top: int,
+ n_epochs: int,
+ mols_to_sample: int,
+ optimize_n_epochs: int,
+ benchmark_num_samples: int,
+ optimize_batch_size: int,
+ random_start: bool,
+ ):
+ """Initialize SMILESLSTMHC.
+ Args:
+ model_path: path to load the model
+ smi_file: path to load the hypothesis, candidate labels and, optionally, the smiles file
+ max_len: maximum length of a SMILES string
+ n_jobs: number of concurrently running jobs
+ keep_top: molecules kept each step
+ n_epochs: number of epochs to sample
+ mols_to_sample: molecules sampled at each step
+ optimize_n_epochs: number of epochs for the optimization
+ benchmark_num_samples: number of molecules to generate from final model for the benchmark
+ optimize_batch_size: batch size for the optimization
+ random_start: set to True to randomly choose list of SMILES for generating optimized molecules
+ """
+ self.model_path = model_path
+ self.n_epochs = n_epochs
+ self.mols_to_sample = mols_to_sample
+ self.keep_top = keep_top
+ self.optimize_n_epochs = optimize_n_epochs
+ self.max_len = max_len
+ self.optimize_batch_size = optimize_batch_size
+ self.benchmark_num_samples = benchmark_num_samples
+ self.random_start = random_start
+ self.smi_file = smi_file
+ self.n_jobs = n_jobs
+
+ def get_generator(self) -> SmilesRnnDirectedGenerator:
+ """
+ used for creating an instance of the SmilesRnnDirectedGenerator
+
+ Returns:
+ An instance of SmilesRnnDirectedGenerator
+ """
+ optimiser = SmilesRnnDirectedGenerator(
+ pretrained_model_path=self.model_path,
+ n_epochs=self.n_epochs,
+ mols_to_sample=self.mols_to_sample,
+ keep_top=self.keep_top,
+ optimize_n_epochs=self.optimize_n_epochs,
+ max_len=self.max_len,
+ optimize_batch_size=self.optimize_batch_size,
+ number_final_samples=self.benchmark_num_samples,
+ random_start=self.random_start,
+ smi_file=self.smi_file,
+ n_jobs=self.n_jobs,
+ )
+ return optimiser
diff --git a/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/smiles_lstm_ppo.py b/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/smiles_lstm_ppo.py
new file mode 100644
index 000000000..46ed73b0b
--- /dev/null
+++ b/src/gt4sd/algorithms/conditional_generation/guacamol/implementation/smiles_lstm_ppo.py
@@ -0,0 +1,54 @@
+"""Recurrent Neural Networks with Proximal Policy Optimization algorithm implementation."""

+from guacamol_baselines.smiles_lstm_ppo.goal_directed_generation import (
+ PPODirectedGenerator,
+)


+class SMILESLSTMPPO:
+ def __init__(
+ self,
+ model_path: str,
+ num_epochs: int,
+ episode_size: int,
+ optimize_batch_size: int,
+ entropy_weight: int,
+ kl_div_weight: int,
+ clip_param: float,
+ ):
+ """Initialize SMILESLSTMPPO.
+
+ Args:
+ model_path: path to load the model
+ num_epochs: number of epochs to sample
+ episode_size: number of molecules sampled by the policy at the start of a series of ppo updates
+ optimize_batch_size: batch size for the optimization
+ entropy_weight: used for calculating entropy loss
+ kl_div_weight: used for calculating Kullback-Leibler divergence loss
+ clip_param: used for determining how far the new policy is from the old one
+ """
+ self.model_path = model_path
+ self.num_epochs = num_epochs
+ self.episode_size = episode_size
+ self.optimize_batch_size = optimize_batch_size
+ self.entropy_weight = entropy_weight
+ self.kl_div_weight = kl_div_weight
+ self.clip_param = clip_param
+
+ def get_generator(self) -> PPODirectedGenerator:
+ """
+ used for creating an instance of the PPODirectedGenerator
+
+ Returns:
+ An instance of PPODirectedGenerator
+ """
+ optimiser = PPODirectedGenerator(
+ pretrained_model_path=None,
+ num_epochs=self.num_epochs,
+ episode_size=self.episode_size,
+ batch_size=self.optimize_batch_size,
+ entropy_weight=self.entropy_weight,
+ kl_div_weight=self.kl_div_weight,
+ clip_param=self.clip_param,
+ )
+ return optimiser
diff --git a/src/gt4sd/algorithms/conditional_generation/key_bert/__init__.py b/src/gt4sd/algorithms/conditional_generation/key_bert/__init__.py
new file mode 100644
index 000000000..4d86a86b5
--- /dev/null
+++ b/src/gt4sd/algorithms/conditional_generation/key_bert/__init__.py
@@ -0,0 +1,5 @@
+"""Keyword generation via BERT models initialization."""

+from .core import KeyBERTGenerator, KeywordBERTGenerationAlgorithm

+__all__ = ["KeywordBERTGenerationAlgorithm", "KeyBERTGenerator"]
diff --git a/src/gt4sd/algorithms/conditional_generation/key_bert/core.py b/src/gt4sd/algorithms/conditional_generation/key_bert/core.py
new file mode 100644
index 000000000..6c2d0fa83
--- /dev/null
+++ b/src/gt4sd/algorithms/conditional_generation/key_bert/core.py
@@ -0,0 +1,180 @@
+"""Algorithms for keyword generation using BERT models."""

+import logging
+from dataclasses import field
+from typing import Any, Callable, ClassVar, Dict, Iterable, Optional, Set, TypeVar

+from ...core import (
+ AlgorithmConfiguration,
+ GeneratorAlgorithm,
+ get_configuration_class_with_attributes,
+)
+from ...registry import ApplicationsRegistry
+from .implementation import KeyBERT

+logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler())

+T = TypeVar("T", bound=Any)
+S = TypeVar("S", bound=Any)
+Targeted = Callable[[T], Iterable[Any]]


+class KeywordBERTGenerationAlgorithm(GeneratorAlgorithm[S, T]):
+ """Keyword generation algorithm."""
+
+ def __init__(
+ self,
+ configuration: AlgorithmConfiguration[S, T],
+ target: Optional[T],
+ ):
+ """Instantiate KeywordBERTGenerationAlgorithm ready to generate keywords.
+
+ Args:
+ configuration: domain and application
+ specification defining parameters, types and validations.
+ target: a target for which to generate items.
+
+ Example:
+ An example for generating keywords for a given text::
+
+ config = KeyBERTGenerator()
+ algorithm = KeywordBERTGenerationAlgorithm(configuration=config, target="This is a text I want to understand better")
+ items = list(algorithm.sample(1))
+ print(items)
+ """
+
+ configuration = self.validate_configuration(configuration)
+ # TODO there might also be a validation/check on the target input
+
+ super().__init__(
+ configuration=configuration, # type:ignore
+ target=target, # type:ignore
+ )
+
+ def get_generator(
+ self,
+ configuration: AlgorithmConfiguration[S, T],
+ target: Optional[T],
+ ) -> Targeted[T]:
+ """Get the function to generate keywords via KeywordBERTGenerationAlgorithm's generator.
+
+ Args:
+ configuration: helps to set up specific application of KeywordBERTGenerationAlgorithm.
+ target: context or condition for the generation.
+
+ Returns:
+ callable with target generating keywords sorted by relevance.
+ """
+ logger.info("ensure artifacts for the application are present.")
+ self.local_artifacts = configuration.ensure_artifacts()
+ implementation: Any = configuration.get_conditional_generator( # type: ignore
+ self.local_artifacts
+ )
+ return implementation.predict


+@ApplicationsRegistry.register_algorithm_application(KeywordBERTGenerationAlgorithm)
+class KeyBERTGenerator(AlgorithmConfiguration[str, str]):
+ """Configuration to generate keywords.
+
+ If the model is not found in the cache, models are collected from https://www.sbert.net/docs/pretrained_models.html.
+ distilbert-base-nli-stsb-mean-tokens is recommended for English, while xlm-r-bert-base-nli-stsb-mean-tokens is
+ recommended for all other languages, as it supports 100+ languages.
+ """
+
+ algorithm_name: ClassVar[str] = KeywordBERTGenerationAlgorithm.__name__
+ algorithm_type: ClassVar[str] = "conditional_generation"
+ domain: ClassVar[str] = "nlp"
+ algorithm_version: str = "distilbert-base-nli-mean-tokens"
+
+ minimum_keyphrase_ngram: int = field(
+ default=1,
+ metadata=dict(description=("Lower bound for phrase size.")),
+ )
+ maximum_keyphrase_ngram: int = field(
+ default=2,
+ metadata=dict(description=("Upper bound for phrase size.")),
+ )
+ stop_words: str = field(
+ default="english",
+ metadata=dict(description=("Language for the stop words removal.")),
+ )
+ top_n: int = field(
+ default=10,
+ metadata=dict(description=("Number of keywords to extract.")),
+ )
+ use_maxsum: bool = field(
+ default=False,
+ metadata=dict(
+ description=("Control usage of max sum similarity for keywords generated.")
+ ),
+ )
+ use_mmr: bool = field(
+ default=False,
+ metadata=dict(
+ description=(
+ "Control usage of max marginal relevance for keywords generated."
+ )
+ ),
+ )
+ diversity: float = field(
+ default=0.5,
+ metadata=dict(description=("Diversity for the results when enabling use_mmr.")),
+ )
+ number_of_candidates: int = field(
+ default=20,
+ metadata=dict(description=("Candidates considered when enabling use_maxsum.")),
+ )
+
+ def get_target_description(self) -> Dict[str, str]:
+ """Get description of the target for generation.
+
+ Returns:
+ target description.
+ """
+ return {
+ "title": "Text to analyze",
+ "description": "Text considered for the keyword generation task.",
+ "type": "string",
+ }
+
+ def get_conditional_generator(self, resources_path: str) -> KeyBERT:
+ """Instantiate the actual generator implementation.
+
+ Args:
+ resources_path: local path to model files.
+
+ Returns:
+ instance with :meth:`predict` method for keyword generation.
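+
+ Example:
+ An illustrative configuration enabling maximal marginal relevance
+ (the values are arbitrary, not tuned recommendations)::
+
+ config = KeyBERTGenerator(top_n=5, use_mmr=True, diversity=0.7)
+ algorithm = KeywordBERTGenerationAlgorithm(
+ configuration=config, target="Text to mine for keywords."
+ )
+ print(list(algorithm.sample(5)))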
+ """ + return KeyBERT( + resources_path=resources_path, + minimum_keyphrase_ngram=self.minimum_keyphrase_ngram, + maximum_keyphrase_ngram=self.maximum_keyphrase_ngram, + stop_words=self.stop_words, + top_n=self.top_n, + use_maxsum=self.use_maxsum, + use_mmr=self.use_mmr, + diversity=self.diversity, + number_of_candidates=self.number_of_candidates, + model_name=self.algorithm_version, + ) + + @classmethod + def list_versions(cls) -> Set[str]: + """Get possible algorithm versions. + + Standard S3 and cache search adding the version used in the configuration. + + Returns: + viable values as :attr:`algorithm_version` for the environment. + """ + logger.warning( + "more algorithm versions can be found on https://www.sbert.net/docs/pretrained_models.html" + ) + return ( + get_configuration_class_with_attributes(cls) + .list_versions() + .union({cls.algorithm_version}) + ) diff --git a/src/gt4sd/algorithms/conditional_generation/key_bert/implementation.py b/src/gt4sd/algorithms/conditional_generation/key_bert/implementation.py new file mode 100644 index 000000000..4151d7c10 --- /dev/null +++ b/src/gt4sd/algorithms/conditional_generation/key_bert/implementation.py @@ -0,0 +1,97 @@ +"""Implementation of the KeyBERT keyword extractor.""" + +import os +from typing import List, Optional, Union + +import torch +from keybert import KeyBERT as KeyBERTCore +from sentence_transformers import SentenceTransformer + +from ....frameworks.torch import device_claim + + +class KeyBERT: + """ + Keyword extractor based on [KeyBERT](https://github.com/MaartenGr/KeyBERT). + """ + + def __init__( + self, + resources_path: str, + minimum_keyphrase_ngram: int, + maximum_keyphrase_ngram: int, + stop_words: Optional[str], + top_n: int, + use_maxsum: bool, + use_mmr: bool, + diversity: float, + number_of_candidates: int, + model_name: str, + device: Optional[Union[torch.device, str]] = None, + ): + """Initialize KeyBERT. + + Args: + resources_path: path where to load hypothesis, candidate labels and, optionally, the model. + minimum_keyphrase_ngram: lower bound for phrase size. + maximum_keyphrase_ngram: upper bound for phrase size. + stop_words: language for the stop words removal. If not provided, no stop words removal. + top_n: number of keywords to extract. + use_maxsum: control usage of max sum similarity for keywords generated. + use_mmr: control usage of max marginal relevance for keywords generated. + diversity: diversity for the results when enabling use_mmr. + number_of_candidates: candidates considered when enabling use_maxsum. + model_name: name of the model to load from the cache or download from SentenceTransformers. + device: device where the inference + is running either as a dedicated class or a string. If not provided is inferred. 
+ """ + self.device = device_claim(device) + self.resources_path = resources_path + self.minimum_keyphrase_ngram = minimum_keyphrase_ngram + self.maximum_keyphrase_ngram = maximum_keyphrase_ngram + self.stop_words = stop_words + self.top_n = top_n + self.use_maxsum = use_maxsum + self.use_mmr = use_mmr + self.diversity = diversity + self.number_of_candidates = number_of_candidates + self.model_name = model_name + self.load_model() + + def load_model(self) -> None: + """Load KeyBERT model.""" + if ( + os.path.exists(self.resources_path) + and len(os.listdir(self.resources_path)) > 0 + ): + model_name_or_path = self.resources_path + else: + model_name_or_path = self.model_name + sentence_model = SentenceTransformer(model_name_or_path, device=self.device) + self.model = KeyBERTCore(model=sentence_model) + + def predict(self, text: str) -> List[str]: + """Get keywords sorted by relevance. + + Args: + text: text to extract keywords from. + + Returns: + keywords sorted by score from highest to lowest. + """ + return [ + keyword + for keyword, _ in self.model.extract_keywords( + text, + keyphrase_ngram_range=( + self.minimum_keyphrase_ngram, + self.maximum_keyphrase_ngram, + ), + stop_words=self.stop_words, + top_n=self.top_n, + use_maxsum=self.use_maxsum, + use_mmr=self.use_mmr, + diversity=self.diversity, + nr_candidates=self.number_of_candidates, + ) + ] diff --git a/src/gt4sd/algorithms/conditional_generation/paccmann_rl/__init__.py b/src/gt4sd/algorithms/conditional_generation/paccmann_rl/__init__.py new file mode 100644 index 000000000..230528803 --- /dev/null +++ b/src/gt4sd/algorithms/conditional_generation/paccmann_rl/__init__.py @@ -0,0 +1,13 @@ +"""PaccMannRL initialization.""" + +from .core import ( + PaccMannRL, + PaccMannRLOmicBasedGenerator, + PaccMannRLProteinBasedGenerator, +) + +__all__ = [ + "PaccMannRL", + "PaccMannRLProteinBasedGenerator", + "PaccMannRLOmicBasedGenerator", +] diff --git a/src/gt4sd/algorithms/conditional_generation/paccmann_rl/core.py b/src/gt4sd/algorithms/conditional_generation/paccmann_rl/core.py new file mode 100644 index 000000000..02d683bb7 --- /dev/null +++ b/src/gt4sd/algorithms/conditional_generation/paccmann_rl/core.py @@ -0,0 +1,269 @@ +"""PaccMann\\ :superscript:`RL` Algorithm. + +PaccMann\\ :superscript:`RL` generation is conditioned via reinforcement learning. +""" + +import logging +from dataclasses import field +from typing import Any, Callable, ClassVar, Dict, Iterable, Optional, TypeVar + +from typing_extensions import Protocol, runtime_checkable + +from ....domains.materials import Omics, Protein, SmallMolecule +from ....exceptions import InvalidItem +from ...core import AlgorithmConfiguration, GeneratorAlgorithm +from ...registry import ApplicationsRegistry +from .implementation import ( + ConditionalGenerator, + ProteinSequenceConditionalGenerator, + TranscriptomicConditionalGenerator, +) + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + +T = TypeVar("T", Protein, Omics) +S = TypeVar("S", bound=SmallMolecule) +Targeted = Callable[[T], Iterable[Any]] + + +class PaccMannRL(GeneratorAlgorithm[S, T]): + """PaccMann\\ :superscript:`RL` Algorithm.""" + + def __init__( + self, + configuration: AlgorithmConfiguration[S, T], + target: Optional[T], + ): + """Instantiate PaccMannRL ready to generate items. + + Args: + configuration: domain and application + specification defining parameters, types and validations. + target: a target for which to generate items. 
+
+ Example:
+ An example for generating small molecules (SMILES) with high affinity
+ for a target protein::
+
+ affinity_config = PaccMannRLProteinBasedGenerator()
+ target = "MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTT"
+ paccmann_affinity = PaccMannRL(configuration=affinity_config, target=target)
+ items = list(paccmann_affinity.sample(10))
+ print(items)
+ """
+
+ configuration = self.validate_configuration(configuration)
+ # TODO there might also be a validation/check on the target input
+
+ super().__init__(
+ configuration=configuration, # type:ignore
+ target=target, # type:ignore
+ )
+
+ def get_generator(
+ self,
+ configuration: AlgorithmConfiguration[S, T],
+ target: Optional[T],
+ ) -> Targeted[T]:
+ """Get the function to sample batches via PaccMannRL's ConditionalGenerator.
+
+ Args:
+ configuration: helps to set up specific application of PaccMannRL.
+ target: context or condition for the generation.
+
+ Returns:
+ callable with target generating a batch of items.
+ """
+ logger.info("ensure artifacts for the application are present.")
+ self.local_artifacts = configuration.ensure_artifacts()
+ implementation: ConditionalGenerator = configuration.get_conditional_generator( # type: ignore
+ self.local_artifacts
+ )
+ return implementation.generate_batch
+
+ def validate_configuration(
+ self, configuration: AlgorithmConfiguration[S, T]
+ ) -> AlgorithmConfiguration[S, T]:
+ @runtime_checkable
+ class AnyPaccMannRLConfiguration(Protocol):
+ """Protocol for PaccMannRL configurations."""
+
+ def get_conditional_generator(
+ self, resources_path: str
+ ) -> ConditionalGenerator:
+ ...
+
+ def validate_item(self, item: Any) -> S:
+ ...
+
+ # TODO raise InvalidAlgorithmConfiguration
+ assert isinstance(configuration, AnyPaccMannRLConfiguration)
+ assert isinstance(configuration, AlgorithmConfiguration)
+ return configuration


+@ApplicationsRegistry.register_algorithm_application(PaccMannRL)
+class PaccMannRLProteinBasedGenerator(AlgorithmConfiguration[SmallMolecule, Protein]):
+ """
+ Configuration to generate compounds with high affinity to a target protein.
+
+ Implementation from the paper: https://doi.org/10.1088/2632-2153/abe808.
+ """
+
+ algorithm_type: ClassVar[str] = "conditional_generation"
+ domain: ClassVar[str] = "materials"
+ algorithm_version: str = "v0"
+
+ batch_size: int = field(
+ default=32,
+ metadata=dict(description="Batch size used for the generative model sampling."),
+ )
+ temperature: float = field(
+ default=1.4,
+ metadata=dict(
+ description="Temperature parameter for the softmax sampling in decoding."
+ ),
+ )
+ generated_length: int = field(
+ default=100,
+ metadata=dict(
+ description="Maximum length in tokens of the generated molecules (relates to the SMILES length)."
+ ),
+ )
+
+ def get_target_description(self) -> Dict[str, str]:
+ """Get description of the target for generation.
+
+ Returns:
+ target description.
+ """
+ return {
+ "title": "Target protein sequence",
+ "description": "AA sequence of the protein target to generate non-toxic ligands against.",
+ "type": "string",
+ }
+
+ def get_conditional_generator(
+ self, resources_path: str
+ ) -> ProteinSequenceConditionalGenerator:
+ """Instantiate the actual generator implementation.
+
+ Args:
+ resources_path: local path to model files.
+
+ Returns:
+ instance with :meth:`generate_batch` method for targeted generation.
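+
+ Example:
+ A sketch of direct use (illustrative; in the algorithm wrapper
+ ``resources_path`` is resolved via ``ensure_artifacts``)::
+
+ config = PaccMannRLProteinBasedGenerator(batch_size=16)
+ generator = config.get_conditional_generator("/path/to/artifacts")
+ smiles = generator.generate_batch("MVLSPADKTNVKAAW") # truncated AA sequence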
+ """ + return ProteinSequenceConditionalGenerator( + resources_path=resources_path, + temperature=self.temperature, + generated_length=self.generated_length, + samples_per_protein=self.batch_size, + ) + + def validate_item(self, item: str) -> SmallMolecule: + """Check that item is a valid SMILES. + + Args: + item: a generated item that is possibly not valid. + + Raises: + InvalidItem: in case the item can not be validated. + + Returns: + the validated SMILES. + """ + ( + molecules, + _, + ) = ProteinSequenceConditionalGenerator.validate_molecules([item]) + if molecules[0] is None: + raise InvalidItem( + title="InvalidSMILES", + detail=f'rdkit.Chem.MolFromSmiles returned None for "{item}"', + ) + return SmallMolecule(item) + + +@ApplicationsRegistry.register_algorithm_application(PaccMannRL) +class PaccMannRLOmicBasedGenerator(AlgorithmConfiguration[SmallMolecule, Omics]): + """ + Configuration to generate compounds with low IC50 for a target omics profile. + + Implementation from the paper: https://doi.org/10.1016/j.isci.2021.102269. + """ + + algorithm_type: ClassVar[str] = "conditional_generation" + domain: ClassVar[str] = "materials" + algorithm_version: str = "v0" + + batch_size: int = field( + default=32, + metadata=dict(description="Batch size used for the generative model sampling."), + ) + temperature: float = field( + default=1.4, + metadata=dict( + description="Temperature parameter for the softmax sampling in decoding." + ), + ) + generated_length: int = field( + default=100, + metadata=dict( + description="Maximum length in tokens of the generated molcules (relates to the SMILES length)." + ), + ) + + def get_target_description(self) -> Dict[str, str]: + """Get description of the target for generation. + + Returns: + target description. + """ + return { + "title": "Gene expression profile", + "description": "A gene expression profile to generate effective molecules against.", + "type": "list", + } + + def get_conditional_generator( + self, resources_path: str + ) -> TranscriptomicConditionalGenerator: + """Instantiate the actual generator implementation. + + Args: + resources_path: local path to model files. + + Returns: + instance with :meth:`generate_batch` method for targeted generation. + """ + return TranscriptomicConditionalGenerator( + resources_path=resources_path, + temperature=self.temperature, + generated_length=self.generated_length, + samples_per_profile=self.batch_size, + ) + + def validate_item(self, item: str) -> SmallMolecule: + """Check that item is a valid SMILES. + + Args: + item: a generated item that is possibly not valid. + + Raises: + InvalidItem: in case the item can not be validated. + + Returns: + the validated SMILES. 
+ """ + ( + molecules, + _, + ) = TranscriptomicConditionalGenerator.validate_molecules([item]) + if molecules[0] is None: + raise InvalidItem( + title="InvalidSMILES", + detail=f'rdkit.Chem.MolFromSmiles returned None for "{item}"', + ) + return SmallMolecule(item) diff --git a/src/gt4sd/algorithms/conditional_generation/paccmann_rl/implementation.py b/src/gt4sd/algorithms/conditional_generation/paccmann_rl/implementation.py new file mode 100644 index 000000000..dcfa0892c --- /dev/null +++ b/src/gt4sd/algorithms/conditional_generation/paccmann_rl/implementation.py @@ -0,0 +1,382 @@ +"""Implementation of PaccMann^RL conditional generators.""" + +import json +import logging +import os +from abc import ABC, abstractmethod +from typing import Any, List, Optional, Set, Tuple, Union + +import numpy as np +import pandas as pd +import torch +from rdkit import Chem +from paccmann_chemistry.models import StackGRUDecoder, StackGRUEncoder, TeacherVAE +from paccmann_chemistry.utils.search import SamplingSearch +from paccmann_omics.encoders import ENCODER_FACTORY +from pytoda.smiles.smiles_language import SMILESLanguage + +from ....domains.materials import validate_molecules +from ....domains.materials.protein_encoding import PrimarySequenceEncoder +from ....frameworks.torch import device_claim +from ....frameworks.torch.vae import reparameterize + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class ConditionalGenerator(ABC): + """Abstract interface for a conditional generator.""" + + #: device where the inference is running. + device: torch.device + #: temperature for the sampling. + temperature: float + #: maximum length of the generated molecules. + generated_length: int + + #: parameters for the SELFIES generator. + selfies_conditional_generator_params: dict + #: SELFIES generator. + selfies_conditional_generator: TeacherVAE + #: SMILES language instance. + smiles_language: SMILESLanguage + + generator_latent_size: int + encoder_latent_size: int + + def get_smiles_from_latent(self, latent: torch.Tensor) -> List[str]: + """Take samples from the latent space. + + Args: + latent: latent vector tensor. + + Returns: + SMILES list and indexes for the valid ones. + """ + if self.generator_latent_size == 2 * self.encoder_latent_size: + latent = latent.repeat(1, 1, 2) + + # generate molecules as tokens list + generated_molecules = self.selfies_conditional_generator.generate( + latent, + prime_input=torch.tensor( + [self.smiles_language.start_index], device=self.device + ).long(), + end_token=torch.tensor( + [self.smiles_language.stop_index], device=self.device + ).long(), + generate_len=self.generated_length, + search=SamplingSearch(temperature=self.temperature), + ) + + # decode SELFIES + selfies = [ + self.smiles_language.token_indexes_to_smiles(generated_molecule.tolist()) + for generated_molecule in iter(generated_molecules) + ] + + # convert SELFIES to SMILES + smiles = [ + self.smiles_language.selfies_to_smiles(a_selfies) for a_selfies in selfies + ] + return smiles + + @staticmethod + def validate_molecules(smiles) -> Tuple[List[Chem.rdchem.Mol], List[int]]: + return validate_molecules(smiles_list=smiles) + + @abstractmethod + def get_latent(self, condition: Any) -> torch.Tensor: + pass + + @abstractmethod + def generate_valid(self, condition: Any, number_of_molecules: int) -> List[str]: + """ + Generate a given number of samples (molecules) from a given condition. + + Args: + protein: the protein used as context/condition. 
+ number_of_molecules: number of molecules to sample.
+
+ Returns:
+ list of SMILES generated.
+ """
+ # prepare the molecule set
+ generated_molecules: Set[str] = set()
+ logger.info("embedding condition and getting reparametrized latent samples")
+ latent = self.get_latent(condition)
+ logger.info("starting generation of molecules")
+ while len(generated_molecules) < number_of_molecules:
+ # generate the molecules
+ generated_smiles = self.get_smiles_from_latent(latent)
+ _, valid_ids = self.validate_molecules(generated_smiles)
+ generated_molecules |= set([generated_smiles[index] for index in valid_ids])
+ logger.info("completed generation of molecules")
+ # return the molecules listed by length
+ return sorted(list(generated_molecules), key=len, reverse=True)[
+ :number_of_molecules
+ ]
+
+ def generate_batch(self, condition: Any) -> List[str]:
+ logger.info("embedding condition and getting reparametrized latent samples")
+ latent = self.get_latent(condition)
+ logger.info("starting generation of molecules")
+ # generate the molecules
+ return self.get_smiles_from_latent(latent)


+class ProteinSequenceConditionalGenerator(ConditionalGenerator):
+ """
+ Protein conditional generator as implemented in https://doi.org/10.1088/2632-2153/abe808
+ (originally https://arxiv.org/abs/2005.13285).
+ It generates ligands with high binding affinity and low toxicity.
+
+ Attributes:
+ samples_per_protein: number of points sampled per protein.
+ It has to be greater than 1.
+ protein_embedding_encoder_params: parameters for the protein embedding encoder.
+ protein_embedding_encoder: protein embedding encoder.
+ """
+
+ def __init__(
+ self,
+ resources_path: str,
+ temperature: float = 1.4,
+ generated_length: int = 100,
+ samples_per_protein: int = 100,
+ device: Optional[Union[torch.device, str]] = None,
+ ) -> None:
+ """
+ Initialize the generator.
+
+ Args:
+ resources_path: directory where to find models and parameters.
+ temperature: temperature for the sampling. Defaults to 1.4.
+ generated_length: maximum length of the generated molecules.
+ Defaults to 100.
+ samples_per_protein: number of points sampled per protein.
+ It has to be greater than 1. Defaults to 100.
+ device: device where the inference
+ is running either as a dedicated class or a string. If not provided is inferred.
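+
+ Example:
+ An instantiation sketch (illustrative; it assumes the artifact files
+ loaded below, e.g. ``selfies_conditional_generator.pt``, are present
+ under ``resources_path``)::
+
+ generator = ProteinSequenceConditionalGenerator(
+ resources_path="/path/to/artifacts", samples_per_protein=10
+ )
+ smiles = generator.generate_batch("MVLSPADKTNVKAAW") # truncated AA sequence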
+ """ + # device + self.device = device_claim(device) + # setting sampling parameters + self.temperature = temperature + self.generated_length = generated_length + self.samples_per_protein = samples_per_protein + # instantiate protein embedding encoder + with open(os.path.join(resources_path, "protein_embedding_params.json")) as fp: + self.protein_embedding_encoder_params = json.load(fp) + self.protein_embedding_encoder = ENCODER_FACTORY["dense"]( + self.protein_embedding_encoder_params + ).to(self.device) + self.protein_embedding_encoder.load( + os.path.join(resources_path, "protein_embedding_encoder.pt"), + map_location=self.device, + ) + self.protein_embedding_encoder.eval() + self.encoder_latent_size = self.protein_embedding_encoder.latent_size + # instantiate selfies conditional generator + with open( + os.path.join(resources_path, "selfies_conditional_generator.json") + ) as fp: + self.selfies_conditional_generator_params = json.load(fp) + self.selfies_conditional_generator = TeacherVAE( + StackGRUEncoder(self.selfies_conditional_generator_params), + StackGRUDecoder(self.selfies_conditional_generator_params), + ).to(self.device) + self.selfies_conditional_generator.load( + os.path.join(resources_path, "selfies_conditional_generator.pt"), + map_location=self.device, + ) + self.selfies_conditional_generator.eval() + self.generator_latent_size = ( + self.selfies_conditional_generator.decoder.latent_dim + ) + # loading SMILES language for decoding and conversion of SELFIES to SMILES + self.smiles_language = SMILESLanguage.load( + os.path.join(resources_path, "selfies_language.pkl") + ) + # protein embedding from primary sequence (via tape) + self.primary_sequence_embedder = PrimarySequenceEncoder( + model_type="transformer", + from_pretrained="bert-base", + model_config_file=None, + tokenizer="iupac", + ).to(self.device) + + def get_latent(self, protein: str) -> torch.Tensor: + """ + Given a protein generate the latent representation. + + Args: + protein: the protein used as context/condition. + + Returns: + the latent representation for the given context. It contains + self.samples_per_protein repeats. + """ + # encode embedded sequence once, ignore the returned dummy ids + embeddings, _ = self.primary_sequence_embedder.forward([[protein]]) + protein_mu, protein_logvar = self.protein_embedding_encoder( + embeddings.to(self.device) + ) + + # now stack as batch to generate different samples + proteins_mu = torch.cat([protein_mu] * self.samples_per_protein, dim=0) + proteins_logvar = torch.cat([protein_logvar] * self.samples_per_protein, dim=0) + # get latent representation + return torch.unsqueeze(reparameterize(proteins_mu, proteins_logvar), 0) + + def generate_valid(self, protein: str, number_of_molecules: int) -> List[str]: + """ + Generate a given number of samples (molecules) from a given protein. + + Args: + protein: the protein used as context/condition. + number_of_molecules: number of molecules to sample. + + Returns: + list of SMILES generated. + """ + return super().generate_valid( + condition=protein, number_of_molecules=number_of_molecules + ) + + def generate_batch(self, protein: str) -> List[str]: + return super().generate_batch(condition=protein) + + +class TranscriptomicConditionalGenerator(ConditionalGenerator): + """ + Transcriptomic conditional generator as implemented in https://doi.org/10.1016/j.isci.2021.102269 + (originally https://doi.org/10.1007/978-3-030-45257-5_18, https://arxiv.org/abs/1909.05114). 
+ It generates highly effective small molecules against transcriptomic profiles.
+
+ Attributes:
+ samples_per_profile: number of points sampled per profile.
+ It has to be greater than 1.
+ transcriptomic_encoder_params: parameters for the transcriptomic encoder.
+ transcriptomic_encoder: transcriptomic encoder.
+ """
+
+ def __init__(
+ self,
+ resources_path: str,
+ temperature: float = 1.4,
+ generated_length: int = 100,
+ samples_per_profile: int = 100,
+ device: Optional[Union[torch.device, str]] = None,
+ ) -> None:
+ """
+ Initialize the generator.
+
+ Args:
+ resources_path: directory where to find models and parameters.
+ temperature: temperature for the sampling. Defaults to 1.4.
+ generated_length: maximum length of the generated molecules.
+ Defaults to 100.
+ samples_per_profile: number of points sampled per profile.
+ It has to be greater than 1. Defaults to 100.
+ device: device where the inference
+ is running either as a dedicated class or a string. If not provided is inferred.
+ """
+ # device
+ self.device = device_claim(device)
+ # setting sampling parameters
+ self.temperature = temperature
+ self.generated_length = generated_length
+ self.samples_per_profile = samples_per_profile
+ with open(os.path.join(resources_path, "genes.txt")) as fp:
+ self.genes = [gene.strip() for gene in fp if gene]
+ # instantiate transcriptomic encoder
+ with open(os.path.join(resources_path, "transcriptomic_params.json")) as fp:
+ self.transcriptomic_encoder_params = json.load(fp)
+ self.transcriptomic_encoder = ENCODER_FACTORY["dense"](
+ self.transcriptomic_encoder_params
+ ).to(self.device)
+ self.transcriptomic_encoder.load(
+ os.path.join(resources_path, "transcriptomic_encoder.pt"),
+ map_location=self.device,
+ )
+ self.transcriptomic_encoder.eval()
+ self.encoder_latent_size = self.transcriptomic_encoder.latent_size
+ # instantiate selfies conditional generator
+ with open(
+ os.path.join(resources_path, "selfies_conditional_generator.json")
+ ) as fp:
+ self.selfies_conditional_generator_params = json.load(fp)
+ self.selfies_conditional_generator = TeacherVAE(
+ StackGRUEncoder(self.selfies_conditional_generator_params),
+ StackGRUDecoder(self.selfies_conditional_generator_params),
+ ).to(self.device)
+ self.selfies_conditional_generator.load(
+ os.path.join(resources_path, "selfies_conditional_generator.pt"),
+ map_location=self.device,
+ )
+ self.selfies_conditional_generator.eval()
+ self.generator_latent_size = (
+ self.selfies_conditional_generator.decoder.latent_dim
+ )
+ # loading SMILES language for decoding and conversion of SELFIES to SMILES
+ self.smiles_language = SMILESLanguage.load(
+ os.path.join(resources_path, "selfies_language.pkl")
+ )
+
+ def get_latent(self, profile: Union[np.ndarray, pd.Series, str]) -> torch.Tensor:
+ """
+ Given a profile generate the latent representation.
+
+ Args:
+ profile: the profile used as context/condition.
+
+ Raises:
+ ValueError: in case the profile has a size mismatch with the genes panel.
+
+ Returns:
+ the latent representation for the given context. It contains
+ self.samples_per_profile repeats.
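+
+ Example:
+ Accepted profile forms (illustrative; the expected size is given by
+ the gene panel read from ``genes.txt``)::
+
+ latent = generator.get_latent(profile_series) # pd.Series indexed by gene names
+ latent = generator.get_latent(profile_array) # np.ndarray matching the panel size
+ latent = generator.get_latent("[0.1, 0.5]") # JSON-serialized list of values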
+ """ + if isinstance(profile, pd.Series): + # make sure genes are sorted + profile = profile[self.genes].values + elif isinstance(profile, str): + logger.warning("profile passed as string, serializing it to a list") + profile = np.array(json.loads(profile)) + if profile.size != len(self.genes): + raise ValueError( + f"provided profile size ({profile.size}) does not match required size {len(self.genes)}" + ) + # encode embedded progiles + transcriptomic_mu, transcriptomic_logvar = self.transcriptomic_encoder( + torch.from_numpy( + np.vstack([profile] * self.samples_per_profile), + ) + .float() + .to(self.device) + ) + # get latent representation + return torch.unsqueeze( + reparameterize(transcriptomic_mu, transcriptomic_logvar), 0 + ) + + def generate_valid( + self, profile: Union[np.ndarray, pd.Series], number_of_molecules: int + ) -> List[str]: + """ + Generate a given number of samples (molecules) from a given transcriptomic profile. + + Args: + profile: the profile used as context/condition. + number_of_molecules: number of molecules to sample. + + Returns: + list of SMILES generated. + """ + return super().generate_valid( + condition=profile, number_of_molecules=number_of_molecules + ) + + def generate_batch(self, profile: Union[np.ndarray, pd.Series]) -> List[str]: + return super().generate_batch(condition=profile) diff --git a/src/gt4sd/algorithms/conditional_generation/regression_transformer/__init__.py b/src/gt4sd/algorithms/conditional_generation/regression_transformer/__init__.py new file mode 100644 index 000000000..ce59b819b --- /dev/null +++ b/src/gt4sd/algorithms/conditional_generation/regression_transformer/__init__.py @@ -0,0 +1,12 @@ +"""Regression Transformer initialization.""" +from .core import ( + RegressionTransformer, + RegressionTransformerMolecules, + RegressionTransformerProteins, +) + +__all__ = [ + "RegressionTransformer", + "RegressionTransformerMolecules", + "RegressionTransformerProteins", +] diff --git a/src/gt4sd/algorithms/conditional_generation/regression_transformer/core.py b/src/gt4sd/algorithms/conditional_generation/regression_transformer/core.py new file mode 100644 index 000000000..2171ce432 --- /dev/null +++ b/src/gt4sd/algorithms/conditional_generation/regression_transformer/core.py @@ -0,0 +1,352 @@ +"""RegressionTransformer algorithm. + +RegressionTransformer is a mutlitask regression and conditional generation model. +""" + +import logging +from dataclasses import field +from typing import Any, Callable, ClassVar, Dict, Iterable, Optional, TypeVar, Union + +from typing_extensions import Protocol, runtime_checkable + +from ....domains.materials import Molecule, Property, Sequence +from ....exceptions import InvalidItem +from ...core import AlgorithmConfiguration, GeneratorAlgorithm +from ...registry import ApplicationsRegistry +from .implementation import ChemicalLanguageRT, ConditionalGenerator, ProteinLanguageRT + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + +T = TypeVar("T", bound=Sequence) +S = TypeVar("S", Property, Molecule) +Targeted = Callable[[T], Iterable[Any]] + + +class RegressionTransformer(GeneratorAlgorithm[S, T]): + """RegressionTransformer Algorithm.""" + + #: The maximum number of samples a user can try to run in one go + max_samples: int = 50 + + def __init__( + self, + configuration: AlgorithmConfiguration[S, T], + target: Optional[T], + ) -> None: + """Instantiate Regression Transformer ready to generate items. 
+
+ Args:
+ configuration: domain and application
+ specification defining parameters, types and validations.
+ target: a target for which to generate items.
+
+ Example:
+ An example for generating protein sequences conditioned on a desired stability value::
+
+ config = RegressionTransformerProteins(
+ search='sample', temperature=2.0, tolerance=10
+ )
+ target = "0.393|GSQEVNSGT[MASK][MASK][MASK]YKNASPEEAE[MASK][MASK]IARKAGATTWTEKGNKWEIRI"
+ stability_generator = RegressionTransformer(configuration=config, target=target)
+ items = list(stability_generator.sample(10))
+ print(items)
+ """
+
+ configuration = self.validate_configuration(configuration)
+
+ # No validation/check on the target input here, since model is not yet loaded.
+ super().__init__(
+ configuration=configuration, # type:ignore
+ target=target, # type:ignore
+ )
+
+ def get_generator(
+ self,
+ configuration: AlgorithmConfiguration[S, T],
+ target: Optional[T],
+ ) -> Targeted[T]:
+ """Get the function to sample with the given configuration.
+
+ Args:
+ configuration: helps to set up specific application of RegressionTransformer.
+ target: context or condition for the generation.
+
+ Returns:
+ callable with target generating a batch of items.
+ """
+ logger.info("ensure artifacts for the application are present.")
+ self.local_artifacts = configuration.ensure_artifacts()
+ implementation: ConditionalGenerator = configuration.get_conditional_generator( # type: ignore
+ resources_path=self.local_artifacts, context=target
+ )
+ if implementation.task == "regression" and configuration.search == "greedy": # type: ignore
+ self.max_samples = 1
+ logger.warning(
+ "max_samples was set to 1 due to regression task and greedy search"
+ )
+
+ return implementation.generate_batch # type: ignore
+
+ def validate_configuration(
+ self, configuration: AlgorithmConfiguration[S, T]
+ ) -> AlgorithmConfiguration[S, T]:
+ @runtime_checkable
+ class AnyRegressionTransformerConfiguration(Protocol):
+ """Protocol for RegressionTransformer configurations."""
+
+ def get_conditional_generator(
+ self, resources_path: str
+ ) -> ConditionalGenerator:
+ ...
+
+ def validate_item(self, item: Any) -> S:
+ ...
+
+ # TODO raise InvalidAlgorithmConfiguration
+ assert isinstance(configuration, AnyRegressionTransformerConfiguration)
+ assert isinstance(configuration, AlgorithmConfiguration)
+ return configuration


+@ApplicationsRegistry.register_algorithm_application(RegressionTransformer)
+class RegressionTransformerMolecules(AlgorithmConfiguration[Sequence, Sequence]):
+ """
+ Configuration to generate molecules given a continuous property target and a molecular sub-structure.
+
+ Implementation from the paper: https://arxiv.org/abs/2202.01338.
+
+ Examples:
+ An example for generating a molecule around a desired property value::
+
+ config = RegressionTransformerMolecules(
+ search='sample', temperature=2, tolerance=5
+ )
+ target = "-3.534|[Br][C][=C][C][MASK][MASK][=C][C][=C][C][=C][Ring1][MASK][MASK][Branch2_3][Ring1][Branch1_2]"
+ esol_generator = RegressionTransformer(
+ configuration=config, target=target
+ )
+ list(esol_generator.sample(5))
+
+ An example for predicting the solubility of a molecule::
+
+ config = RegressionTransformerMolecules(search='greedy')
+ target = "[MASK][MASK][MASK][MASK][MASK]|[Cl][C][Branch1_2][Branch1_2][=C][Branch1_1][C][Cl][Cl][Cl]"
+ esol_generator = RegressionTransformer(
+ configuration=config, target=target
+ )
+ list(esol_generator.sample(1))
+ """
+
+ algorithm_type: ClassVar[str] = "conditional_generation"
+ domain: ClassVar[str] = "materials"
+ algorithm_version: str = "v0"
+
+ search: str = field(
+ default="sample",
+ metadata=dict(
+ description="Search algorithm to use for the generation: sample or greedy"
+ ),
+ )
+
+ temperature: float = field(
+ default=1.4,
+ metadata=dict(
+ description="Temperature parameter for the softmax sampling in decoding."
+ ),
+ )
+ batch_size: int = field(
+ default=8,
+ metadata=dict(description="Batch size for the conditional generation"),
+ )
+ tolerance: float = field(
+ default=20.0,
+ metadata=dict(
+ description="Precision tolerance for the conditional generation task. Given in percent"
+ ),
+ )
+
+ def get_target_description(self) -> Dict[str, str]:
+ """Get description of the target for generation.
+
+ Returns:
+ target description.
+ """
+ return {
+ "title": "Masked input sequence",
+ "description": (
+ "A sequence with a property value and a SELFIES string. Masking can either occur on the property or on the SELFIES, but not both. "
+ "For the scale of the property values, please see the task/dataset."
+ ),
+ "type": "string",
+ }
+
+ def get_conditional_generator(
+ self, resources_path: str, context: str
+ ) -> ChemicalLanguageRT:
+ """Instantiate the actual generator implementation.
+
+ Args:
+ resources_path: local path to model files.
+ context: input sequence to be used for the generation.
+
+ Returns:
+ instance with :meth:`generate_batch` method for targeted generation.
+ """
+ self.generator = ChemicalLanguageRT(
+ resources_path=resources_path,
+ context=context,
+ search=self.search,
+ temperature=self.temperature,
+ batch_size=self.batch_size,
+ tolerance=self.tolerance,
+ )
+ return self.generator
+
+ def validate_item(self, item: str) -> Union[Molecule, Property]: # type: ignore
+ """Check that item is a valid sequence.
+
+ Args:
+ item: a generated item that is possibly not valid.
+
+ Raises:
+ InvalidItem: in case the item can not be validated.
+
+ Returns:
+ the validated item.
+ """
+ if item is None:
+ raise InvalidItem(title="InvalidSequence", detail="Sequence is None")
+ (
+ items,
+ _,
+ ) = self.generator.validate_output([item])
+ if items[0] is None:
+ if self.generator.task == "generation":
+ title = "InvalidSMILES"
+ detail = f'rdkit.Chem.MolFromSmiles returned None for "{item}"'
+ else:
+ title = "InvalidNumerical"
+ detail = f'"{item}" is not a valid floating point number'
+ raise InvalidItem(title=title, detail=detail)
+ return item


+@ApplicationsRegistry.register_algorithm_application(RegressionTransformer)
+class RegressionTransformerProteins(AlgorithmConfiguration[Sequence, Sequence]):
+ """
+ Configuration to generate proteins given a continuous property target and a partial amino acid sequence (AAS).
+
+ Implementation from the paper: https://arxiv.org/abs/2202.01338.
It can also predict the property given a full sequence. + + Examples: + An example for generating a peptide around a desired property value:: + + config = RegressionTransformerProteins( + search='sample', temperature=2, tolerance=5 + ) + target = "1.1234|TTIKNG[MASK][MASK][MASK]YTVPLSPEQAAK[MASK][MASK][MASK]KKRWPDYEVQIHGNTVKVT" + stability_generator = RegressionTransformer( + configuration=config, target=target + ) + list(stability_generator.sample(5)) + + An example for predicting the stability of a peptide:: + + config = RegressionTransformerProteins(search='greedy') + target = "[MASK][MASK][MASK][MASK][MASK]|GSQEVNSNASPEEAEIARKAGATTWTEKGNKWEIRI" + stability_generator = RegressionTransformer( + configuration=config, target=target + ) + list(stability_generator.sample(1)) + """ + + algorithm_type: ClassVar[str] = "conditional_generation" + domain: ClassVar[str] = "materials" + algorithm_version: str = "v0" + + search: str = field( + default="sample", + metadata=dict( + description="Search algorithm to use for the generation: sample or greedy" + ), + ) + + temperature: float = field( + default=1.4, + metadata=dict( + description="Temperature parameter for the softmax sampling in decoding." + ), + ) + batch_size: int = field( + default=32, + metadata=dict(description="Batch size for the conditional generation"), + ) + tolerance: float = field( + default=20.0, + metadata=dict( + description="Precision tolerance for the conditional generation task. Given in percent" + ), + ) + + def get_target_description(self) -> Dict[str, str]: + """Get description of the target for generation. + + Returns: + target description. + """ + return { + "title": "Masked input sequence", + "description": "A sequence with a property value and an AAS. Masking can either occur on the property or on the AAS, but not both.", + "type": "string", + } + + def get_conditional_generator( + self, resources_path: str, context: str + ) -> ProteinLanguageRT: + """Instantiate the actual generator implementation. + + Args: + resources_path: local path to model files. + context: input sequence to be used for the generation. + + Returns: + instance with :meth:`generate_batch` method for targeted generation. + """ + + self.generator = ProteinLanguageRT( + resources_path=resources_path, + search=self.search, + temperature=self.temperature, + context=context, + batch_size=self.batch_size, + tolerance=self.tolerance, + ) + return self.generator + + def validate_item(self, item: str) -> Union[Molecule, Property]: # type: ignore + """Check that item is a valid sequence. + + Args: + item: a generated item that is possibly not valid. + + Raises: + InvalidItem: in case the item can not be validated. + + Returns: + the validated item. 
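+
+        Example:
+            An illustrative call (the item is a hypothetical filled-in sequence;
+            it assumes the generator was already set up via
+            :meth:`get_conditional_generator`)::
+
+                validated = configuration.validate_item(
+                    "TTIKNGREWYTVPLSPEQAAKFLEKKRWPDYEVQIHGNTVKVT"
+                )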
+        """
+        if item is None:
+            raise InvalidItem(title="InvalidSequence", detail="Sequence is None")
+        (
+            items,
+            _,
+        ) = self.generator.validate_output([item])
+        if items[0] is None:
+            if self.generator.task == "generation":
+                title = "InvalidSequence"
+                detail = f'"{item}" does not adhere to IUPAC convention for AAS'
+            else:
+                title = "InvalidNumerical"
+                detail = f'"{item}" is not a valid floating point number'
+            raise InvalidItem(title=title, detail=detail)
+        return item
diff --git a/src/gt4sd/algorithms/conditional_generation/regression_transformer/implementation.py b/src/gt4sd/algorithms/conditional_generation/regression_transformer/implementation.py
new file mode 100644
index 000000000..f3f039d6c
--- /dev/null
+++ b/src/gt4sd/algorithms/conditional_generation/regression_transformer/implementation.py
@@ -0,0 +1,589 @@
+"""Implementation of Regression Transformer conditional generators."""
+import json
+import logging
+import os
+from typing import Any, List, Optional, Tuple, Union
+
+import torch
+from terminator.collators import MaskedTextCollator, PropertyCollator
+from terminator.inference import InferenceRT
+from terminator.search import SEARCH_FACTORY, Search
+from terminator.selfies import decoder
+from terminator.tokenization import InferenceBertTokenizer
+from transformers import AutoConfig, AutoModelWithLMHead, XLNetLMHeadModel
+
+from ....domains.materials import Property, Sequence, validate_molecules
+from ....frameworks.torch import device_claim
+
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler())
+
+
+class ConditionalGenerator:
+    """Main interface for a regression transformer."""
+
+    # device where the inference is running.
+    device: torch.device
+
+    # The task the RT is currently performing. Either 'regression' or 'generation'
+    task: str
+
+    # method to convert logits to predictions. Either GreedySearch or SamplingSearch
+    search: Search
+
+    # ***** Additional attributes for text generation *****
+    # data collator for property prediction of self-generated items
+    property_collator: PropertyCollator
+
+    # percentage of tolerated deviation between desired and obtained property
+    tolerance: float = 20
+
+    # number of samples obtained per call
+    batch_size: int = 8
+
+    def __init__(
+        self, resources_path: str, device: Optional[Union[torch.device, str]] = None
+    ) -> None:
+        """
+        Initialize the generator.
+
+        Args:
+            resources_path: directory where to find models and parameters.
+            device: device where the inference is running, either as a dedicated class
+                or a string. If not provided, it is inferred.
+        """
+        # device
+        self.device = device_claim(device)
+
+        # Set up the data preparation pipeline
+        self.tokenizer = InferenceBertTokenizer.from_pretrained(
+            resources_path, pad_even=False
+        )
+        self.collator = MaskedTextCollator(self.tokenizer)
+
+        # Set up model: First load the pretrained XLNet model
+        xlnet_model, config = self.load_model(resources_path)
+        # Initialize the custom RT model
+        self.model = InferenceRT(xlnet_model, self.tokenizer, config)
+
+        # Set up inference parameters
+        self.load_inference(resources_path)
+
+    def load_model(self, resources_path: str) -> Tuple[XLNetLMHeadModel, Any]:
+        """
+        Load an XLNetLMHeadModel, which constitutes the base of an RT model.
+
+        Args:
+            resources_path: path to the model.
+
+        Returns:
+            XLNetLMHeadModel: base of a Regression Transformer model.
+            XLNetConfig: configuration of the model.
+        """
+        config_name = os.path.join(resources_path, "config.json")
+        config = AutoConfig.from_pretrained(config_name, mem_len=1024)
+        xlnet_model = AutoModelWithLMHead.from_pretrained(
+            resources_path, from_tf=False, config=config
+        )
+        xlnet_model.resize_token_embeddings(len(self.tokenizer))
+        xlnet_model.to(self.device)
+        xlnet_model.eval()
+        logger.info(f"Model restored from {resources_path}")
+        return xlnet_model, config
+
+    def load_inference(self, resources_path: str) -> None:
+        """
+        Load and set up all parameters necessary for inference.
+
+        Args:
+            resources_path: path to the model folder.
+        """
+        try:
+            with open(os.path.join(resources_path, "inference.json"), "r") as f:
+                data = json.load(f)
+            self.property = data["property_token"]
+            self.property_mask_length = data["property_mask_length"][self.property]
+            self.min_ = data.get("property_ranges", {}).get(self.property, [0, 1])[0]
+            self.max_ = data.get("property_ranges", {}).get(self.property, [0, 1])[1]
+            self.metadata = data
+        except Exception:
+            raise ValueError(
+                f"Could not restore inference parameters from {resources_path}"
+            )
+
+    def denormalize(self, x: float, precision: int = 4) -> float:
+        """
+        Denormalize from the [0,1] scale to the original scale.
+
+        Args:
+            x: normalized value.
+            precision: optional rounding precision. Defaults to 4.
+
+        Returns:
+            float: value in the regular scale.
+        """
+        return round(x * (self.max_ - self.min_) + self.min_, precision)
+
+    def normalize(self, x: float, precision: int = 3) -> float:
+        """
+        Normalize from the original scale to the [0,1] scale.
+
+        Args:
+            x: unnormalized input.
+            precision: optional rounding precision.
+
+        Returns:
+            float: normalized value.
+        """
+        return round((x - self.min_) / (self.max_ - self.min_), precision)
+
+    def validate_input(self, x: str) -> None:
+        """
+        Sanity-check the formatting of a raw input sequence.
+
+        Args:
+            x: the user-provided input sequence for the model.
+
+        Raises:
+            ValueError: if the separator, mask or property tokens are missing or misplaced.
+        """
+        if self.tokenizer.expression_separator not in x:
+            raise ValueError(
+                f"Expression separator {self.tokenizer.expression_separator} not "
+                f"found in input {x}."
+            )
+        if self.tokenizer.mask_token not in x:
+            raise ValueError(
+                f"Nothing to do, no mask to fill ({self.tokenizer.mask_token}) found "
+                f"in input {x}."
+            )
+        if self.property not in x:
+            raise ValueError(f"No property token ({self.property}) found in input")
+
+        text_sequence = x.split(self.tokenizer.expression_separator)[-1]
+        number_sequence = x[: -len(text_sequence) - 1]
+        if (
+            self.tokenizer.mask_token in text_sequence
+            and self.tokenizer.mask_token in number_sequence
+        ):
+            raise ValueError(
+                f"Do not mask number and text sequence at the same time like in {x}."
+            )
+        self.validate_input_molecule(text_sequence)
+
+    def validate_input_molecule(self, sequence: str) -> None:
+        """
+        Verifies that the non-numerical part of the input is a proper sequence.
+
+        Args:
+            sequence: input sequence to be validated.
+        """
+        raise NotImplementedError
+
+    def safely_determine_task(self, x: str) -> str:
+        """
+        Determines whether the passed sequence adheres to the regression or generation task.
+
+        Args:
+            x: the user-provided input sequence for the model, including mask tokens.
+
+        Raises:
+            ValueError: if the sequence does not adhere to the formatting rules.
+
+        Returns:
+            str: the task, either 'regression' or 'generation'.
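+
+        Example:
+            Illustrative inputs, taken from the configuration examples above
+            (mask placement determines the task)::
+
+                # masks in the text part -> 'generation'
+                "0.393|GSQEVNSGT[MASK][MASK][MASK]YKNASPEEAE[MASK][MASK]IARKAGATTWTEKGNKWEIRI"
+                # masks only on the property part -> 'regression'
+                "[MASK][MASK][MASK][MASK][MASK]|GSQEVNSNASPEEAEIARKAGATTWTEKGNKWEIRI"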
+        """
+
+        self.validate_input(x)
+        if (
+            self.tokenizer.mask_token
+            in x.split(self.tokenizer.expression_separator)[-1]
+        ):
+            return "generation"
+
+        if x.count(self.tokenizer.mask_token) != self.property_mask_length:
+            raise ValueError(
+                f"To predict {self.property} you have to mask {self.property_mask_length} times"
+            )
+
+        return "regression"
+
+    def generate_batch_regression(self, context: Sequence) -> List[Property]:
+        """
+        Predict the property of a sample.
+
+        Args:
+            context: a string with a masked property, a separator and an
+                entity. E.g. [MASK][MASK][MASK][MASK]|GSQEVNSGTQTYKNASPEEAERIARKAGATTWTEKGNKWEIRI.
+
+        Returns:
+            List[Property]: a list of (denormalized) predicted properties for the entity.
+        """
+        logger.info(f"Starting prediction for sequence {context}")
+
+        # Prepare the batch
+        inputs = self.collator([self.tokenizer(context)])
+        input_ids = inputs["input_ids"].cpu()
+
+        # Forward pass
+        outputs = self.model(inputs)
+
+        # Obtain the singular predictions
+        prediction = self.search(outputs["logits"].detach())
+
+        return self.compile_regression_result(input_ids, prediction)
+
+    def compile_regression_result(
+        self, input_ids: torch.Tensor, prediction: torch.Tensor
+    ) -> List[Property]:
+        """
+        Postprocesses the prediction from the property task to obtain a float.
+
+        Args:
+            input_ids: 2D Tensor of shape (batch_size, sequence_length).
+            prediction: 2D Tensor of shape (batch_size, sequence_length).
+
+        Returns:
+            List[Property]: list of property values.
+        """
+        properties = []
+        for inp, pred in zip(input_ids, prediction):
+            in_tokens = self.tokenizer.decode(
+                inp, clean_up_tokenization_spaces=False
+            ).split(" ")
+            out_tokens = self.tokenizer.decode(
+                pred, clean_up_tokenization_spaces=False
+            ).split(" ")
+            joined = self.tokenizer.get_sample_prediction(out_tokens, in_tokens)
+            _, gen_prop = self.tokenizer.aggregate_tokens(joined, label_mode=False)
+            properties.append(self.denormalize(gen_prop[self.property[1:-1]]))
+        return properties
+
+    def generate_batch_generation(self, sequence: Sequence) -> Tuple:
+        """
+        Conditionally generate sequences given a continuous property value and a fixed
+        sequence. This function first conditionally generates the novel sequences and
+        then predicts their properties using the RT again. A novel sequence is only
+        returned if its predicted property lies within the tolerance range.
+
+        Args:
+            sequence: the input sequence with masked tokens on the text.
+
+        Returns:
+            Tuple[Tuple[str, float]]: a tuple of tuples, each containing the generated
+                sequence alongside its predicted property value.
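+
+        Example:
+            An illustrative return value (the sequence and the number are
+            hypothetical)::
+
+                (("TTIKNGREWYTVPLSPEQAAKFLEKKRWPDYEVQIHGNTVKVT", 1.09),)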
+        """
+
+        # The property score has to be in the range [0, 1]
+        sequence = self.normalize_sequence(sequence)
+
+        logger.warning(f"Starting prediction for sequence {sequence}")
+
+        # Prepare the batch
+        tokens = self.tokenizer(sequence)
+        inputs = self.collator([tokens] * self.batch_size)
+        input_ids = inputs["input_ids"].clone()
+        # Forward pass
+        outputs = self.model(inputs)
+        # Obtain model predictions via the search method
+        predictions = self.search(outputs["logits"].detach()).squeeze()
+        # Combine predictions with the static part to obtain the full sequences
+        generations = input_ids
+        generations[generations == self.tokenizer.mask_token_id] = predictions[
+            generations == self.tokenizer.mask_token_id
+        ]
+
+        # Second part: Predict the properties of the just generated sequences
+        _input = self.property_collator.mask_tokens(generations)
+        prediction_input = {
+            "input_ids": _input[0],
+            "perm_mask": _input[1],
+            "target_mapping": _input[2],
+            "attention_mask": self.property_collator.attention_mask(generations),
+        }
+        # Pass through model
+        property_outputs = self.model(prediction_input)
+        # It's a design choice to go with greedy predictions here
+        predictions = torch.argmax(property_outputs["logits"].detach(), dim=-1)
+        # Obtain floating predictions
+        properties = self.compile_regression_result(generations, predictions)
+        # Obtain the sequences (AAS or SELFIES)
+        sequences = [
+            self.tokenizer.to_readable(
+                "".join(
+                    self.tokenizer.decode(seq, skip_special_tokens=True).split(" ")
+                ).split(self.tokenizer.expression_separator)[-1]
+            )
+            for seq in generations
+        ]
+        successes: Tuple = tuple(
+            filter(
+                lambda x: abs(self.normalize(x[1]) - self.target_value)
+                < self.tolerance,
+                zip(sequences, properties),
+            )
+        )  # type: ignore
+        logger.info(f"Successes: {successes}")
+        return successes
+
+    def normalize_sequence(self, context: Sequence) -> Sequence:
+        """
+        Take a sequence with an unnormalized property score and convert it to a
+        sequence with a normalized score.
+
+        Args:
+            context: sequence with unnormalized property.
+
+        Returns:
+            Sequence: sequence with normalized property.
+        """
+        tokens = self.tokenizer.tokenize(context)
+        numerical_tokens = tokens[
+            tokens.index(self.property)
+            + 1 : tokens.index(self.tokenizer.expression_separator)
+        ]
+
+        unnorm = self.tokenizer.floating_tokens_to_float(numerical_tokens)
+        # Declared as an instance attribute since it is used by other methods
+        self.target_value = self.normalize(unnorm)
+        norm = str(self.target_value)[: self.property_mask_length]
+
+        tokens = (
+            "".join(tokens[: tokens.index(self.property) + 1])
+            + norm
+            + "".join(tokens[tokens.index(self.tokenizer.expression_separator) :])
+        )
+        return "".join(tokens)
+
+    @staticmethod
+    def validate_numerical(sequences: List[Any]):
+        """
+        Validate whether a list of sequences contains only numerical values.
+
+        Args:
+            sequences: a list of possibly numerical values.
+
+        Returns:
+            a tuple containing the validated floats (with None for invalid entries)
+            and the valid indexes.
+        """
+
+        items = [item if isinstance(item, float) else None for item in sequences]
+        idxs = [i for i, item in enumerate(sequences) if isinstance(item, float)]
+        return items, idxs
+
+
+class ChemicalLanguageRT(ConditionalGenerator):
+    """
+    Hybrid regression and conditional molecular generation model as implemented in
+    https://arxiv.org/abs/2202.01338. It generates molecules with a desired solubility
+    (ESOL) score or predicts the ESOL of a given molecule.
+    For details on the ESOL task see: https://pubs.acs.org/doi/10.1021/ci034243x
+
+    Attributes:
+        resources_path: path to the model.
+        context: user-specified input text for the model.
+        search: search key to instantiate a search via terminator.search.SEARCH_FACTORY.
+        temperature: the temperature parameter in case of a `sample` search.
+        batch_size: the batch size for the model, applicable only to the generative task.
+        tolerance: the tolerance for the property of the generated molecules.
+    """
+
+    def __init__(
+        self,
+        resources_path: str,
+        context: str,
+        search: str = "sample",
+        temperature: float = 1.4,
+        batch_size: int = 8,
+        tolerance: float = 20.0,
+        device: Optional[Union[torch.device, str]] = None,
+    ) -> None:
+        """
+        Initialize the molecule generator.
+
+        Args:
+            resources_path: directory where to find models and parameters.
+            context: user-specified input text for the model.
+            search: search key to instantiate a search, defaults to `sample`.
+            temperature: temperature for the sampling. Defaults to 1.4.
+            batch_size: number of points sampled per call. Defaults to 8.
+            tolerance: the tolerance for the property of the generated molecules.
+                Given in percent. Defaults to 20.
+            device: device where the inference is running, either as a dedicated class
+                or a string. If not provided, it is inferred.
+        """
+        super().__init__(device=device, resources_path=resources_path)
+
+        # Validate input and determine task
+        self.task = self.safely_determine_task(context)
+
+        # Console outputs for usage of search methods
+        if search == "sample" and self.task == "regression":
+            logger.warning("For regression task, greedy search is recommended")
+        elif search == "greedy" and self.task == "generation":
+            logger.warning("For generation task, sample search is recommended")
+
+        if search not in SEARCH_FACTORY.keys():
+            raise KeyError(f"Pick a search of {SEARCH_FACTORY.keys()} not: {search}.")
+        self.search = SEARCH_FACTORY[search](temperature=temperature)
+
+        # Based on task, set the correct generate_batch method
+        if self.task == "regression":
+            self.generate_batch = self.generate_batch_regression
+        elif self.task == "generation":
+            self.generate_batch = self.generate_batch_generation  # type: ignore
+            self.batch_size = batch_size
+            self.property_collator = PropertyCollator(
+                tokenizer=self.tokenizer,
+                property_tokens=[self.property],
+                num_tokens_to_mask=[-1],
+                ignore_errors=False,
+            )
+            # Tolerance on [0,1] scale
+            self.tolerance = tolerance / 100.0
+        else:
+            raise ValueError(f"Unknown task: {self.task}")
+
+    def validate_input_molecule(self, sequence: str) -> None:
+        """
+        Verifies that the non-numerical part of the input sequence is a SELFIES.
+
+        Args:
+            sequence: input sequence to be validated.
+        """
+        # Fractional molecules based on non-masked parts of the SELFIES sequence
+        smis = list(map(decoder, sequence.split(self.tokenizer.mask_token)))
+        if -1 in smis:
+            raise ValueError(f"Invalid sequence: {sequence}")
+
+    def validate_output(self, sequences: List[Any]) -> Tuple[List[Any], List[int]]:
+        """
+        Validate the output of the RT model.
+
+        Args:
+            sequences: list of sequences to be validated.
+
+        Returns:
+            A tuple of validated items (Chem.rdchem.Mol in the case of a generation
+            task, floating values otherwise) and a list of valid indexes.
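+
+        Example:
+            An illustrative call for a generation task (hypothetical values; the
+            input pairs are of the form returned by
+            :meth:`generate_batch_generation`)::
+
+                items, indexes = generator.validate_output([("CCO", -0.77)])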
+        """
+
+        if self.task == "regression":
+            return self.validate_numerical(sequences)
+        else:
+            # Convert SELFIES to SMILES
+            smiles_list = list(
+                filter(lambda x: x is not None, list(zip(*sequences))[0])
+            )
+            if smiles_list == []:
+                return ([None], [-1])
+            return validate_molecules(smiles_list=smiles_list)  # type: ignore
+
+
+class ProteinLanguageRT(ConditionalGenerator):
+    """
+    Hybrid regression and conditional protein generation model as implemented in
+    https://arxiv.org/abs/2202.01338. It generates peptides with a desired stability
+    score or predicts the stability score of a given protein.
+    For details on the stability task see: https://doi.org/10.1126/science.aan0693
+
+    Attributes:
+        resources_path: path to the model.
+        context: user-specified input text for the model.
+        search: search key to instantiate a search via terminator.search.SEARCH_FACTORY.
+        temperature: the temperature parameter in case of a `sample` search.
+        batch_size: the batch size for the model, applicable only to the generative task.
+        tolerance: the tolerance for the property of the generated sequences.
+    """
+
+    def __init__(
+        self,
+        resources_path: str,
+        context: str,
+        search: str = "sample",
+        temperature: float = 1.4,
+        batch_size: int = 32,
+        tolerance: float = 20.0,
+        device: Optional[Union[torch.device, str]] = None,
+    ) -> None:
+        """
+        Initialize the protein generator.
+
+        Args:
+            resources_path: directory where to find models and parameters.
+            context: user-specified input text for the model.
+            search: search key to instantiate a search, defaults to `sample`.
+            temperature: temperature for the sampling. Defaults to 1.4.
+            batch_size: number of points sampled per call. Defaults to 32.
+            tolerance: the tolerance for the property of the generated sequences.
+                Given in percent. Defaults to 20.
+            device: device where the inference is running, either as a dedicated class
+                or a string. If not provided, it is inferred.
+        """
+        super().__init__(device=device, resources_path=resources_path)
+
+        # Validate input and determine task
+        self.task = self.safely_determine_task(context)
+
+        # Console outputs for usage of search methods
+        if search == "sample" and self.task == "regression":
+            logger.warning("For regression task, greedy search is recommended")
+        elif search == "greedy" and self.task == "generation":
+            logger.warning("For generation task, sample search is recommended")
+        if search not in SEARCH_FACTORY.keys():
+            raise KeyError(f"Pick a search of {SEARCH_FACTORY.keys()} not: {search}.")
+        self.search = SEARCH_FACTORY[search](temperature=temperature)
+
+        # Based on task, set the correct generate_batch method
+        if self.task == "regression":
+            self.generate_batch = self.generate_batch_regression
+        elif self.task == "generation":
+            self.generate_batch = self.generate_batch_generation  # type: ignore
+            self.batch_size = batch_size
+            self.property_collator = PropertyCollator(
+                tokenizer=self.tokenizer,
+                property_tokens=[self.property],
+                num_tokens_to_mask=[-1],
+                ignore_errors=False,
+            )
+            # Tolerance on [0,1] scale
+            self.tolerance = tolerance / 100.0
+        else:
+            raise ValueError(f"Unknown task: {self.task}")
+
+    def validate_input_molecule(self, sequence: str) -> None:
+        """
+        Verifies that the non-numerical part of the input sequence is a valid AAS.
+
+        Args:
+            sequence: input sequence to be validated.
+        """
+        if sequence != sequence.upper():
+            raise ValueError(
+                f"Sequence {sequence} does not follow IUPAC convention for AAS"
+            )
+
+    def validate_output(self, sequences: Any) -> Tuple[List[Any], List[int]]:
+        """
+        Validate the output of the RT model.
+
+        Args:
+            sequences: list of sequences to be validated.
+
+        Returns:
+            A tuple of validated items and a list of valid indexes.
+        """
+
+        if self.task == "regression":
+            return self.validate_numerical(sequences)
+        else:
+            items = [
+                item
+                if (
+                    (
+                        item[0] == item[0].upper()
+                        and self.tokenizer.mask_token not in item[0]
+                        and not any([s.isdigit() for s in item[0]])
+                    )
+                    and isinstance(item[1], float)
+                )
+                else None
+                for item in sequences
+            ]
+            idxs = [i for i, item in enumerate(sequences) if item in items]
+            return items, idxs
diff --git a/src/gt4sd/algorithms/conditional_generation/reinvent/__init__.py b/src/gt4sd/algorithms/conditional_generation/reinvent/__init__.py
new file mode 100644
index 000000000..56c0f7113
--- /dev/null
+++ b/src/gt4sd/algorithms/conditional_generation/reinvent/__init__.py
@@ -0,0 +1,5 @@
+"""REINVENT initialization."""
+
+from .core import Reinvent, ReinventGenerator
+
+__all__ = ["Reinvent", "ReinventGenerator"]
diff --git a/src/gt4sd/algorithms/conditional_generation/reinvent/core.py b/src/gt4sd/algorithms/conditional_generation/reinvent/core.py
new file mode 100644
index 000000000..101e70c1c
--- /dev/null
+++ b/src/gt4sd/algorithms/conditional_generation/reinvent/core.py
@@ -0,0 +1,123 @@
+import logging
+from dataclasses import field
+from typing import Any, Callable, ClassVar, Dict, Iterable, Optional, TypeVar
+
+from ...core import AlgorithmConfiguration, GeneratorAlgorithm
+from ...registry import ApplicationsRegistry
+from .implementation import ReinventConditionalGenerator
+
+T = TypeVar("T", bound=Any)
+S = TypeVar("S", bound=Any)
+Targeted = Callable[[T], Iterable[Any]]
+
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler())
+
+
+class Reinvent(GeneratorAlgorithm[S, T]):
+    """Reinvent sample generation algorithm."""
+
+    def __init__(
+        self,
+        configuration: AlgorithmConfiguration[S, T],
+        target: Optional[T],
+    ):
+        """Instantiate Reinvent ready to generate samples.
+
+        Args:
+            configuration: domain and application
+                specification defining parameters, types and validations.
+            target: a target for which to generate items.
+
+        Example:
+            An example for generating molecules given a scaffold SMILES as target
+            (here an empty scaffold)::
+
+                config = ReinventGenerator()
+                algorithm = Reinvent(configuration=config, target="")
+                items = list(algorithm.sample(1))
+                print(items)
+        """
+
+        configuration = self.validate_configuration(configuration)
+        # TODO there might also be a validation/check on the target input
+
+        super().__init__(
+            configuration=configuration,  # type:ignore
+            target=target,  # type:ignore
+        )
+
+    def get_generator(
+        self,
+        configuration: AlgorithmConfiguration[S, T],
+        target: Optional[T],
+    ) -> Targeted[T]:
+        """Get the function to perform the generation via Reinvent's generator.
+
+        Args:
+            configuration: helps to set up the specific application of Reinvent.
+            target: context or condition for the generation.
+
+        Returns:
+            callable with target generating samples.
+        """
+        logger.info("ensure artifacts for the application are present.")
+        self.local_artifacts = configuration.ensure_artifacts()
+        implementation: ReinventConditionalGenerator = configuration.get_conditional_generator(  # type: ignore
+            self.local_artifacts
+        )
+        return implementation.generate_samples
+
+
+@ApplicationsRegistry.register_algorithm_application(Reinvent)
+class ReinventGenerator(AlgorithmConfiguration[str, str]):
+    """Configuration to generate molecules using the REINVENT algorithm.
It generates molecules by decorating a given scaffold."""
+
+    algorithm_name: ClassVar[str] = Reinvent.__name__
+    algorithm_type: ClassVar[str] = "conditional_generation"
+    domain: ClassVar[str] = "materials"
+    algorithm_version: str = "v0"
+
+    batch_size: int = field(
+        default=20,
+        metadata=dict(description=("Number of samples to generate per scaffold")),
+    )
+
+    randomize: bool = field(
+        default=True,
+        metadata=dict(description=("Randomize the scaffolds if set to true")),
+    )
+
+    sample_uniquely: bool = field(
+        default=True,
+        metadata=dict(description=("Generate unique sample sequences if set to true")),
+    )
+
+    def get_target_description(self) -> Dict[str, str]:
+        """Get description of the target for generation.
+
+        Returns:
+            target description.
+        """
+        return {
+            "title": "SMILES for sample generation",
+            "description": "SMILES considered for the samples generation.",
+            "type": "string",
+        }
+
+    def get_conditional_generator(
+        self, resources_path: str
+    ) -> ReinventConditionalGenerator:
+        """Instantiate the actual generator implementation.
+
+        Args:
+            resources_path: local path to model files.
+
+        Returns:
+            instance with :meth:`generate_samples` method for targeted generation.
+        """
+        return ReinventConditionalGenerator(
+            resources_path=resources_path,
+            batch_size=self.batch_size,
+            randomize=self.randomize,
+            sample_uniquely=self.sample_uniquely,
+        )
diff --git a/src/gt4sd/algorithms/conditional_generation/reinvent/implementation.py b/src/gt4sd/algorithms/conditional_generation/reinvent/implementation.py
new file mode 100644
index 000000000..32750b65a
--- /dev/null
+++ b/src/gt4sd/algorithms/conditional_generation/reinvent/implementation.py
@@ -0,0 +1,114 @@
+"""Implementation of Reinvent conditional generators."""
+
+import logging
+import os
+from typing import List, NamedTuple, Optional, Set, Tuple
+
+from reinvent_models.lib_invent.models.model import DecoratorModel
+
+from .reinvent_core.core import ReinventBase, SampledSequencesDTO
+
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler())
+
+
+class SampledTuple(NamedTuple):
+    scaffold: str
+    decoration: str
+    nll: float
+
+
+class ReinventConditionalGenerator(ReinventBase):
+    def __init__(
+        self,
+        resources_path: str,
+        batch_size: int,
+        randomize: bool,
+        sample_uniquely: bool,
+    ):
+        """Initialize Reinvent.
+
+        Args:
+            resources_path: path where to load hypothesis, candidate labels and, optionally, the model.
+            batch_size: number of samples to generate per scaffold.
+            randomize: randomize the scaffolds if set to true.
+            sample_uniquely: generate unique sample sequences if set to true.
+        """
+        self.resources_path = resources_path
+        self.batch_size = batch_size
+        self.randomize = randomize
+        self.sample_uniquely = sample_uniquely
+        self.model_path = os.path.join(self.resources_path, "model.prior")
+        self.target: Optional[str] = None
+
+        if not os.path.isfile(self.model_path):
+            logger.debug("reinvent model file does not exist locally")
+            raise OSError(f"artifacts file {self.model_path} does not exist locally")
+
+        self.model = DecoratorModel.load_from_file(path=self.model_path)
+        # NOTE: pass randomize and sample_uniquely as keywords, since the base
+        # class signature has a `logger` parameter in the third position.
+        super().__init__(
+            self.model,
+            self.batch_size,
+            randomize=self.randomize,
+            sample_uniquely=self.sample_uniquely,
+        )
+
+    def sample_unique_sequences(self, sampled_sequences: List[Tuple]) -> List[Tuple]:
+        """
+        Filter the sampled sequences, keeping only the unique ones.
+
+        Args:
+            sampled_sequences: a list of (scaffold, decoration, nll) tuples.
+
+        Returns:
+            A list of unique (scaffold, decoration, nll) tuples.
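+
+        Example:
+            An illustrative call (hypothetical values); the duplicate is removed::
+
+                unique = generator.sample_unique_sequences(
+                    [("C1CC1[*]", "c1ccccc1", 12.3), ("C1CC1[*]", "c1ccccc1", 12.3)]
+                )
+                # -> [("C1CC1[*]", "c1ccccc1", 12.3)]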
+        """
+        sequences = [
+            SampledSequencesDTO(scaffold, decoration, nll)
+            for scaffold, decoration, nll in sampled_sequences
+        ]
+        logger.info("getting unique sample sequences from generated samples")
+        return [
+            (sample.scaffold, sample.decoration, sample.nll)
+            for sample in self._sample_unique_sequences(sequences)
+        ]
+
+    def generate_sampled_tuples(self, scaffold: str) -> Set[SampledTuple]:
+        """
+        Sample the model for the given scaffold.
+
+        Args:
+            scaffold: a scaffold SMILES.
+
+        Returns:
+            A set of SampledTuple.
+        """
+        if self.target != scaffold:
+            self.target = scaffold
+            batch = next(iter(self.get_dataloader([scaffold])))
+            logger.info("initialization of the dataloader")
+            scaffold_seqs, scaffold_seq_lengths = batch
+            self.scaffold_seqs = scaffold_seqs.expand(
+                self.batch_size - 1, scaffold_seqs.shape[1]
+            )
+            self.scaffold_seq_lengths = scaffold_seq_lengths.expand(self.batch_size - 1)
+        logger.info("started generating samples with an nll score value")
+        sampled_sequences = list(
+            self.model.sample_decorations(self.scaffold_seqs, self.scaffold_seq_lengths)
+        )
+        if self.sample_uniquely:
+            sampled_sequences = self.sample_unique_sequences(sampled_sequences)
+
+        return set(
+            [
+                SampledTuple(scaffold, decoration, nll)
+                for scaffold, decoration, nll in sampled_sequences
+            ]
+        )
+
+    def generate_samples(self, scaffold: str) -> Set[str]:
+        """
+        Sample the model for the given scaffold.
+
+        Args:
+            scaffold: a scaffold SMILES.
+
+        Returns:
+            A set of SMILES representing molecules.
+        """
+        return set(
+            [molecule for _, molecule, _ in self.generate_sampled_tuples(scaffold)]
+        )
diff --git a/src/gt4sd/algorithms/conditional_generation/reinvent/reinvent_core/LICENSE b/src/gt4sd/algorithms/conditional_generation/reinvent/reinvent_core/LICENSE
new file mode 100644
index 000000000..ceddc8ea1
--- /dev/null
+++ b/src/gt4sd/algorithms/conditional_generation/reinvent/reinvent_core/LICENSE
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2021 Atanas Patronov. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/src/gt4sd/algorithms/conditional_generation/reinvent/reinvent_core/README.md b/src/gt4sd/algorithms/conditional_generation/reinvent/reinvent_core/README.md new file mode 100644 index 000000000..425d58d0e --- /dev/null +++ b/src/gt4sd/algorithms/conditional_generation/reinvent/reinvent_core/README.md @@ -0,0 +1,7 @@ +# MolecularAI Reinvent Code Explanation + +The code for getting unique sample sequences, randomizing scaffolds, and generation of the dataset as well as the dataloader was taken from the original implementation of [Molecular Reinvent](https://github.com/MolecularAI/Reinvent) and can be found in the class [ReinventBase](/gt4sd/algorithms/conditional_generation/reinvent/reinvent_core/core.py) which is in the subdirectory **reinvent_core**. 
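+
+A minimal usage sketch (illustrative only; it assumes a lib-invent `DecoratorModel` instance named `model` is already loaded, and the scaffold SMILES is hypothetical — see also the notes on *get_dataloader* below):
+
+```python
+from gt4sd.algorithms.conditional_generation.reinvent.reinvent_core.core import ReinventBase
+
+base = ReinventBase(model, batch_size=4, randomize=True, sample_uniquely=True)
+# get_dataloader returns a torch dataloader over the cleaned scaffolds
+# instead of the sampled sequences returned by the original `run`.
+dataloader = base.get_dataloader(["C1CC1[*]"])
+for scaffold_seqs, scaffold_seq_lengths in dataloader:
+    decorations = model.sample_decorations(scaffold_seqs, scaffold_seq_lengths)
+```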
+
+We have created a new function *get_dataloader*, a modified version of the function *[run](https://github.com/MolecularAI/Reinvent/blob/982b26dd6cfeb8aa84b6d7e4a8c2a7edde2bad36/running_modes/lib_invent/rl_actions/sample_model.py#:~:text=def%20run(self%2C%20scaffold_list%3A%20List%5Bstr%5D)%20-%3E%20List%5BSampledSequencesDTO%5D%3A)* that returns an instance of the dataloader instead of the sampled sequences. It can be found in the [ReinventBase](/gt4sd/algorithms/conditional_generation/reinvent/reinvent_core/core.py) class.
+
+Moreover, we have not included [BaseAction](https://github.com/MolecularAI/Reinvent/blob/982b26dd6cfeb8aa84b6d7e4a8c2a7edde2bad36/running_modes/lib_invent/rl_actions/sample_model.py#:~:text=class%20BaseAction(abc.ABC)%3A) as a parent class of [ReinventBase](/gt4sd/algorithms/conditional_generation/reinvent/reinvent_core/core.py), to which we have added all the required functions from [Molecular Reinvent](https://github.com/MolecularAI/Reinvent).
diff --git a/src/gt4sd/algorithms/conditional_generation/reinvent/reinvent_core/__init__.py b/src/gt4sd/algorithms/conditional_generation/reinvent/reinvent_core/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/gt4sd/algorithms/conditional_generation/reinvent/reinvent_core/core.py b/src/gt4sd/algorithms/conditional_generation/reinvent/reinvent_core/core.py
new file mode 100644
index 000000000..d23521afa
--- /dev/null
+++ b/src/gt4sd/algorithms/conditional_generation/reinvent/reinvent_core/core.py
@@ -0,0 +1,125 @@
+"""MolecularAI implementation of sample generation, scaffold randomization and fetching of unique sample sequences.
+
+The source of this file is
+https://raw.githubusercontent.com/MolecularAI/Reinvent/982b26dd6cfeb8aa84b6d7e4a8c2a7edde2bad36/running_modes/lib_invent/rl_actions/sample_model.py
+and it was only minimally changed. See README.md.
+"""
+
+__copyright__ = "Copyright 2021, MolecularAI"
+__license__ = "Apache 2.0"
+
+import logging
+from dataclasses import dataclass
+from typing import List
+
+import numpy as np
+import torch.utils.data as tud
+from reinvent_chemistry import Conversions
+from reinvent_chemistry.library_design import AttachmentPoints, BondMaker
+from reinvent_chemistry.utils import get_indices_of_unique_smiles
+from reinvent_models.lib_invent.models import dataset as md
+
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler())
+
+
+@dataclass
+class SampledSequencesDTO:
+    scaffold: str
+    decoration: str
+    nll: float
+
+
+class ReinventBase:
+    def __init__(
+        self, model, batch_size: int, logger=None, randomize=False, sample_uniquely=True
+    ):
+        """
+        Creates an instance of SampleModel.
+        :params model: A model instance (better in scaffold_decorating mode).
+        :params batch_size: Batch size to use.
+        :return:
+        """
+        self.model = model
+        self._batch_size = batch_size
+        self._bond_maker = BondMaker()
+        self._attachment_points = AttachmentPoints()
+        self._randomize = randomize
+        self._conversions = Conversions()
+        self._sample_uniquely = sample_uniquely
+
+    def get_dataloader(self, scaffold_list: List[str]) -> tud.DataLoader:
+        """
+        Get a dataloader for the list of scaffolds to use with reinvent.
+        NOTE: This method was factored out of the `run` method from the original source.
+        :params scaffold_list: A list of scaffold SMILES.
+        :return: An instance of a torch dataloader.
+ """ + scaffold_list = ( + self._randomize_scaffolds(scaffold_list) + if self._randomize + else scaffold_list + ) + clean_scaffolds = [ + self._attachment_points.remove_attachment_point_numbers(scaffold) + for scaffold in scaffold_list + ] + dataset = md.Dataset( + clean_scaffolds, + self.model.vocabulary.scaffold_vocabulary, + self.model.vocabulary.scaffold_tokenizer, + ) + dataloader = tud.DataLoader( + dataset, + batch_size=len(dataset), + shuffle=False, + collate_fn=md.Dataset.collate_fn, + ) + return dataloader + + def run(self, scaffold_list: List[str]) -> List[SampledSequencesDTO]: + """ + Samples the model for the given number of SMILES. + NOTE: this method was slightly adapted from the original source. + :params scaffold_list: A list of scaffold SMILES. + :return: A list of SampledSequencesDTO. + """ + + dataloader = self.get_dataloader(scaffold_list) + + sampled_sequences = [] + for batch in dataloader: + + for _ in range(self._batch_size): + scaffold_seqs, scaffold_seq_lengths = batch + packed = self.model.sample_decorations( + scaffold_seqs, scaffold_seq_lengths + ) + for scaffold, decoration, nll in packed: + sampled_sequences.append( + SampledSequencesDTO(scaffold, decoration, nll) + ) + + if self._sample_uniquely: + sampled_sequences = self._sample_unique_sequences(sampled_sequences) + + return sampled_sequences + + def _sample_unique_sequences( + self, sampled_sequences: List[SampledSequencesDTO] + ) -> List[SampledSequencesDTO]: + strings = [ + "".join([ss.scaffold, ss.decoration]) + for index, ss in enumerate(sampled_sequences) + ] + unique_idxs = get_indices_of_unique_smiles(strings) + sampled_sequences_np = np.array(sampled_sequences) + unique_sampled_sequences = sampled_sequences_np[unique_idxs] + return unique_sampled_sequences.tolist() + + def _randomize_scaffolds(self, scaffolds: List[str]): + scaffold_mols = [ + self._conversions.smile_to_mol(scaffold) for scaffold in scaffolds + ] + randomized = [self._bond_maker.randomize_scaffold(mol) for mol in scaffold_mols] + return randomized diff --git a/src/gt4sd/algorithms/conditional_generation/template/__init__.py b/src/gt4sd/algorithms/conditional_generation/template/__init__.py new file mode 100644 index 000000000..084801e6f --- /dev/null +++ b/src/gt4sd/algorithms/conditional_generation/template/__init__.py @@ -0,0 +1,8 @@ +"""Template initialization.""" + +from .core import Template, TemplateGenerator + +__all__ = [ + "Template", + "TemplateGenerator", +] diff --git a/src/gt4sd/algorithms/conditional_generation/template/core.py b/src/gt4sd/algorithms/conditional_generation/template/core.py new file mode 100644 index 000000000..701620d65 --- /dev/null +++ b/src/gt4sd/algorithms/conditional_generation/template/core.py @@ -0,0 +1,105 @@ +"""Template Algorithm""" + +import logging +from dataclasses import field +from typing import Any, Callable, ClassVar, Dict, Iterable, Optional, TypeVar + +from ...core import AlgorithmConfiguration, GeneratorAlgorithm # type: ignore +from ...registry import ApplicationsRegistry # type: ignore +from .implementation import Generator # type: ignore + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + +T = TypeVar("T") +S = TypeVar("S") +Targeted = Callable[[T], Iterable[Any]] + + +class Template(GeneratorAlgorithm[S, T]): + """Template Algorithm.""" + + def __init__( + self, configuration: AlgorithmConfiguration[S, T], target: Optional[T] = None + ): + """Template Generation + + Args: + configuration: domain and application + specification, defining types 
and validations.
+            target: optional, depending on the type of generative model. In this template
+                we will convert the target to a string.
+
+        Example:
+            An example for using this template::
+
+                target = 'World'
+                configuration = TemplateGenerator()
+                algorithm = Template(configuration=configuration, target=target)
+                items = list(algorithm.sample(1))
+                print(items)
+        """
+
+        configuration = self.validate_configuration(configuration)
+        # TODO there might also be a validation/check on the target input
+
+        super().__init__(
+            configuration=configuration,
+            target=target,  # type:ignore
+        )
+
+    def get_generator(
+        self,
+        configuration: AlgorithmConfiguration[S, T],
+        target: Optional[T],
+    ) -> Targeted[T]:
+        """Get the function to sample from the generator.
+
+        Args:
+            configuration: helps to set up the application.
+            target: context or condition for the generation. Just an optional string here.
+
+        Returns:
+            callable generating a list of 1 item containing a salutation and the temperature converted to Fahrenheit.
+        """
+        logger.info("ensure artifacts for the application are present.")
+        self.local_artifacts = configuration.ensure_artifacts()
+        implementation: Generator = configuration.get_conditional_generator(  # type: ignore
+            self.local_artifacts
+        )
+        return implementation.hello_name  # type:ignore
+
+    def validate_configuration(
+        self, configuration: AlgorithmConfiguration
+    ) -> AlgorithmConfiguration:
+        # TODO raise InvalidAlgorithmConfiguration
+        assert isinstance(configuration, AlgorithmConfiguration)
+        return configuration
+
+
+@ApplicationsRegistry.register_algorithm_application(Template)
+class TemplateGenerator(AlgorithmConfiguration[str, str]):
+    """Configuration for the template generator."""
+
+    algorithm_type: ClassVar[str] = "conditional_generation"
+    domain: ClassVar[str] = "materials"
+    algorithm_version: str = "v0"
+
+    temperature: int = field(
+        default=36,
+        metadata=dict(description="Temperature parameter (in celsius)"),
+    )
+
+    def get_target_description(self) -> Dict[str, str]:
+        """Get description of the target for generation.
+
+        Returns:
+            target description.
+        """
+        return {
+            "title": "Target name",
+            "description": "A simple string to define the name in the output [Hello name].",
+            "type": "string",
+        }
+
+    def get_conditional_generator(self, resources_path: str) -> Generator:
+        """Instantiate the actual generator implementation.
+
+        Args:
+            resources_path: local path to model files.
+
+        Returns:
+            instance with :meth:`hello_name` method for targeted generation.
+        """
+        return Generator(resources_path=resources_path, temperature=self.temperature)
diff --git a/src/gt4sd/algorithms/conditional_generation/template/implementation.py b/src/gt4sd/algorithms/conditional_generation/template/implementation.py
new file mode 100644
index 000000000..e438066fd
--- /dev/null
+++ b/src/gt4sd/algorithms/conditional_generation/template/implementation.py
@@ -0,0 +1,35 @@
+"""Implementation details for a Template algorithm."""
+
+import random
+from typing import List
+
+
+class Generator:
+    """Basic Generator for the template algorithm."""
+
+    def __init__(self, resources_path: str, temperature: int) -> None:
+        """Initialize the Generator.
+
+        Args:
+            resources_path: directory where to find models and parameters.
+            temperature: temperature parameter (in celsius).
+        """
+
+        self.resources_path = resources_path
+        self.temperature = temperature
+
+    def hello_name(
+        self,
+        name: str,
+    ) -> List[str]:
+        """Generate a salutation message for the given name.
+
+        Args:
+            name: a string.
+
+        Returns:
+            a list containing a salutation and the temperature converted to Fahrenheit.
+        """
+        return [
+            f"Hello {str(name)} {random.randint(1, int(1e6))} times and, fun fact, {str(self.temperature)} celsius equals {(self.temperature * (9/5) + 32)} fahrenheit."
+ ] diff --git a/src/gt4sd/algorithms/conditional_generation/tests/__init__.py b/src/gt4sd/algorithms/conditional_generation/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/gt4sd/algorithms/conditional_generation/tests/test_guacamol.py b/src/gt4sd/algorithms/conditional_generation/tests/test_guacamol.py new file mode 100644 index 000000000..09e4645ed --- /dev/null +++ b/src/gt4sd/algorithms/conditional_generation/tests/test_guacamol.py @@ -0,0 +1,230 @@ +"""Guacamol Baselines tests.""" + +from typing import ClassVar, Type + +import pytest + +from gt4sd.algorithms.conditional_generation.guacamol import ( + GraphGAGenerator, + GraphMCTSGenerator, + GuacaMolGenerator, + SMILESGAGenerator, + SMILESLSTMHCGenerator, + SMILESLSTMPPOGenerator, +) +from gt4sd.algorithms.core import AlgorithmConfiguration +from gt4sd.algorithms.registry import ApplicationsRegistry +from gt4sd.tests.utils import GT4SDTestSettings + +TARGET = {"isomer_scorer": {"target": 5.0, "target_smile": "NCCCCC"}} +algorithm_parameters = { + "smiles_ga": {"random_start": True}, + "graph_ga": {"random_start": True}, + "graph_mcts": { + "init_smiles": "CC", + "population_size": 5, + "generations": 5, + "num_sims": 10, + "max_children": 5, + "max_atoms": 10, + }, + "smiles_lstm_hc": { + "random_start": True, + "mols_to_sample": 10, + "keep_top": 5, + "max_len": 2, + "optimize_batch_size": 3, + "n_epochs": 2, + }, + "smiles_lstm_ppo": {"num_epochs": 2, "episode_size": 10, "optimize_batch_size": 2}, +} + +test_settings = GT4SDTestSettings.get_instance() + + +def get_classvar_type(class_var): + """Extract type from ClassVar type annotation: `ClassVar[T]] -> T`.""" + return class_var.__args__[0] + + +@pytest.mark.parametrize( + "config_class, algorithm_type, domain, algorithm_name", + [ + ( + SMILESGAGenerator, + "conditional_generation", + "materials", + GuacaMolGenerator.__name__, + ), + ( + GraphGAGenerator, + "conditional_generation", + "materials", + GuacaMolGenerator.__name__, + ), + ( + GraphMCTSGenerator, + "conditional_generation", + "materials", + GuacaMolGenerator.__name__, + ), + ( + SMILESLSTMHCGenerator, + "conditional_generation", + "materials", + GuacaMolGenerator.__name__, + ), + ( + SMILESLSTMPPOGenerator, + "conditional_generation", + "materials", + GuacaMolGenerator.__name__, + ), + ], +) +def test_config_class( + config_class: Type[AlgorithmConfiguration], + algorithm_type: str, + domain: str, + algorithm_name: str, +): + assert config_class.algorithm_type == algorithm_type + assert config_class.domain == domain + assert config_class.algorithm_name == algorithm_name + + for keyword, type_annotation in config_class.__annotations__.items(): + if keyword in ("algorithm_type", "domain", "algorithm_name"): + assert type_annotation.__origin__ is ClassVar # type: ignore + assert str == get_classvar_type(type_annotation) + + +@pytest.mark.parametrize( + "config_class", + [ + (SMILESGAGenerator), + (GraphGAGenerator), + (GraphMCTSGenerator), + (SMILESLSTMHCGenerator), + (SMILESLSTMPPOGenerator), + ], +) +def test_config_instance(config_class: Type[AlgorithmConfiguration]): + config = config_class() # type:ignore + assert config.algorithm_application == config_class.__name__ + + +@pytest.mark.parametrize( + "config_class", + [ + (SMILESGAGenerator), + (GraphGAGenerator), + (GraphMCTSGenerator), + (SMILESLSTMHCGenerator), + (SMILESLSTMPPOGenerator), + ], +) +def test_available_versions(config_class: Type[AlgorithmConfiguration]): + versions = config_class.list_versions() + assert "v0" in 
versions + + +@pytest.mark.parametrize( + "config, algorithm, algorithm_parameters", + [ + pytest.param( + SMILESGAGenerator, + GuacaMolGenerator, + algorithm_parameters["smiles_ga"], + marks=pytest.mark.skipif(test_settings.gt4sd_ci, reason="slow_runtime"), + ), + (GraphGAGenerator, GuacaMolGenerator, algorithm_parameters["graph_ga"]), + (GraphMCTSGenerator, GuacaMolGenerator, algorithm_parameters["graph_mcts"]), + pytest.param( + SMILESLSTMHCGenerator, + GuacaMolGenerator, + algorithm_parameters["smiles_lstm_hc"], + marks=pytest.mark.skipif(test_settings.gt4sd_ci, reason="slow_runtime"), + ), + pytest.param( + SMILESLSTMPPOGenerator, + GuacaMolGenerator, + algorithm_parameters["smiles_lstm_ppo"], + marks=pytest.mark.skipif(test_settings.gt4sd_ci, reason="slow_runtime"), + ), + ], +) +def test_generation_via_import(config, algorithm, algorithm_parameters): + parameters = { + "batch_size": 1, + } + for param, value in algorithm_parameters.items(): + parameters[param] = value + config = config(**parameters) + algorithm = algorithm(configuration=config, target=TARGET) + items = list(algorithm.sample(1)) + assert len(items) == 1 + + +@pytest.mark.parametrize( + "algorithm_application, algorithm_type, domain, algorithm_name, algorithm_parameters", + [ + pytest.param( + SMILESGAGenerator.__name__, + "conditional_generation", + "materials", + GuacaMolGenerator.__name__, + algorithm_parameters["smiles_ga"], + marks=pytest.mark.skipif(test_settings.gt4sd_ci, reason="slow_runtime"), + ), + ( + GraphGAGenerator.__name__, + "conditional_generation", + "materials", + GuacaMolGenerator.__name__, + algorithm_parameters["graph_ga"], + ), + ( + GraphMCTSGenerator.__name__, + "conditional_generation", + "materials", + GuacaMolGenerator.__name__, + algorithm_parameters["graph_mcts"], + ), + pytest.param( + SMILESLSTMHCGenerator.__name__, + "conditional_generation", + "materials", + GuacaMolGenerator.__name__, + algorithm_parameters["smiles_lstm_hc"], + marks=pytest.mark.skipif(test_settings.gt4sd_ci, reason="slow_runtime"), + ), + pytest.param( + SMILESLSTMPPOGenerator.__name__, + "conditional_generation", + "materials", + GuacaMolGenerator.__name__, + algorithm_parameters["smiles_lstm_ppo"], + marks=pytest.mark.skipif(test_settings.gt4sd_ci, reason="slow_runtime"), + ), + ], +) +def test_generation_via_registry( + algorithm_type, + domain, + algorithm_name, + algorithm_application, + algorithm_parameters, +): + parameters = { + "target": TARGET, + "algorithm_type": algorithm_type, + "domain": domain, + "algorithm_name": algorithm_name, + "algorithm_application": algorithm_application, + "batch_size": 1, + } + for param, value in algorithm_parameters.items(): + parameters[param] = value + algorithm = ApplicationsRegistry.get_application_instance(**parameters) + items = list(algorithm.sample(1)) + assert len(items) == 1 diff --git a/src/gt4sd/algorithms/conditional_generation/tests/test_key_bert.py b/src/gt4sd/algorithms/conditional_generation/tests/test_key_bert.py new file mode 100644 index 000000000..7aa38905f --- /dev/null +++ b/src/gt4sd/algorithms/conditional_generation/tests/test_key_bert.py @@ -0,0 +1,105 @@ +"""KeyBERT tests.""" + +from typing import ClassVar, Type + +import pytest + +from gt4sd.algorithms.conditional_generation.key_bert import ( + KeyBERTGenerator, + KeywordBERTGenerationAlgorithm, +) +from gt4sd.algorithms.core import AlgorithmConfiguration +from gt4sd.algorithms.registry import ApplicationsRegistry + + +def get_classvar_type(class_var): + """Extract type from ClassVar type 
annotation: `ClassVar[T]] -> T`.""" + return class_var.__args__[0] + + +@pytest.mark.parametrize( + "config_class, algorithm_type, domain, algorithm_name", + [ + ( + KeyBERTGenerator, + "conditional_generation", + "nlp", + KeywordBERTGenerationAlgorithm.__name__, + ) + ], +) +def test_config_class( + config_class: Type[AlgorithmConfiguration], + algorithm_type: str, + domain: str, + algorithm_name: str, +): + assert config_class.algorithm_type == algorithm_type + assert config_class.domain == domain + assert config_class.algorithm_name == algorithm_name + + for keyword, type_annotation in config_class.__annotations__.items(): + if keyword in ("algorithm_type", "domain", "algorithm_name"): + assert type_annotation.__origin__ is ClassVar # type: ignore + assert str == get_classvar_type(type_annotation) + + +@pytest.mark.parametrize( + "config_class", + [ + (KeyBERTGenerator), + ], +) +def test_config_instance(config_class: Type[AlgorithmConfiguration]): + config = config_class() # type:ignore + assert config.algorithm_application == config_class.__name__ + + +@pytest.mark.parametrize( + "config_class", + [ + (KeyBERTGenerator), + ], +) +def test_available_versions(config_class: Type[AlgorithmConfiguration]): + versions = config_class.list_versions() + assert len(versions) > 0 + + +@pytest.mark.parametrize( + "config, algorithm", + [ + (KeyBERTGenerator, KeywordBERTGenerationAlgorithm), + ], +) +def test_generation_via_import(config, algorithm): + algorithm = algorithm( + configuration=config(), target="This is a text used for the tests." + ) + items = list(algorithm.sample(1)) + assert len(items) == 1 + + +@pytest.mark.parametrize( + "algorithm_application, algorithm_type, domain, algorithm_name", + [ + ( + KeyBERTGenerator.__name__, + "conditional_generation", + "nlp", + KeywordBERTGenerationAlgorithm.__name__, + ), + ], +) +def test_generation_via_registry( + algorithm_type, domain, algorithm_name, algorithm_application +): + algorithm = ApplicationsRegistry.get_application_instance( + target="This is a text used for the tests.", + algorithm_type=algorithm_type, + domain=domain, + algorithm_name=algorithm_name, + algorithm_application=algorithm_application, + ) + items = list(algorithm.sample(1)) + assert len(items) == 1 diff --git a/src/gt4sd/algorithms/conditional_generation/tests/test_moses.py b/src/gt4sd/algorithms/conditional_generation/tests/test_moses.py new file mode 100644 index 000000000..1a43f2f39 --- /dev/null +++ b/src/gt4sd/algorithms/conditional_generation/tests/test_moses.py @@ -0,0 +1,127 @@ +"""Moses tests.""" + +from typing import ClassVar, Type + +import pytest + +from gt4sd.algorithms.conditional_generation.guacamol import ( + AaeGenerator, + MosesGenerator, + OrganGenerator, + VaeGenerator, +) +from gt4sd.algorithms.core import AlgorithmConfiguration +from gt4sd.algorithms.registry import ApplicationsRegistry + + +def get_classvar_type(class_var): + """Extract type from ClassVar type annotation: `ClassVar[T]] -> T`.""" + return class_var.__args__[0] + + +@pytest.mark.parametrize( + "config_class, algorithm_type, domain, algorithm_name", + [ + ( + AaeGenerator, + "conditional_generation", + "materials", + MosesGenerator.__name__, + ), + ( + VaeGenerator, + "conditional_generation", + "materials", + MosesGenerator.__name__, + ), + ( + OrganGenerator, + "conditional_generation", + "materials", + MosesGenerator.__name__, + ), + ], +) +def test_config_class( + config_class: Type[AlgorithmConfiguration], + algorithm_type: str, + domain: str, + algorithm_name: str, +): + 
assert config_class.algorithm_type == algorithm_type + assert config_class.domain == domain + assert config_class.algorithm_name == algorithm_name + + for keyword, type_annotation in config_class.__annotations__.items(): + if keyword in ("algorithm_type", "domain", "algorithm_name"): + assert type_annotation.__origin__ is ClassVar # type: ignore + assert str == get_classvar_type(type_annotation) + + +@pytest.mark.parametrize( + "config_class", + [(AaeGenerator), (VaeGenerator), (OrganGenerator)], +) +def test_config_instance(config_class: Type[AlgorithmConfiguration]): + config = config_class() # type:ignore + assert config.algorithm_application == config_class.__name__ + + +@pytest.mark.parametrize( + "config_class", + [(AaeGenerator), (VaeGenerator), (OrganGenerator)], +) +def test_available_versions(config_class: Type[AlgorithmConfiguration]): + versions = config_class.list_versions() + assert "v0" in versions + + +@pytest.mark.parametrize( + "config, algorithm", + [ + (AaeGenerator, MosesGenerator), + (VaeGenerator, MosesGenerator), + (OrganGenerator, MosesGenerator), + ], +) +def test_generation_via_import(config, algorithm): + config = config() + algorithm = algorithm(configuration=config, target="") + items = list(algorithm.sample(2)) + assert len(items) == 2 + + +@pytest.mark.parametrize( + "algorithm_application, algorithm_type, domain, algorithm_name", + [ + ( + AaeGenerator.__name__, + "conditional_generation", + "materials", + MosesGenerator.__name__, + ), + ( + VaeGenerator.__name__, + "conditional_generation", + "materials", + MosesGenerator.__name__, + ), + ( + OrganGenerator.__name__, + "conditional_generation", + "materials", + MosesGenerator.__name__, + ), + ], +) +def test_generation_via_registry( + algorithm_type, domain, algorithm_name, algorithm_application +): + algorithm = ApplicationsRegistry.get_application_instance( + algorithm_type=algorithm_type, + domain=domain, + algorithm_name=algorithm_name, + algorithm_application=algorithm_application, + ) + items = list(algorithm.sample(5)) + assert len(items) == 5 diff --git a/src/gt4sd/algorithms/conditional_generation/tests/test_paccmann_rl.py b/src/gt4sd/algorithms/conditional_generation/tests/test_paccmann_rl.py new file mode 100644 index 000000000..fbbaa931c --- /dev/null +++ b/src/gt4sd/algorithms/conditional_generation/tests/test_paccmann_rl.py @@ -0,0 +1,162 @@ +"""PaccMannRL tests.""" + +import pickle +from typing import ClassVar, Type + +import numpy as np +import pytest + +from gt4sd.algorithms.conditional_generation.paccmann_rl import ( + PaccMannRL, + PaccMannRLOmicBasedGenerator, + PaccMannRLProteinBasedGenerator, +) +from gt4sd.algorithms.core import AlgorithmConfiguration +from gt4sd.algorithms.registry import ApplicationsRegistry + + +def get_classvar_type(class_var): + """Extract type from ClassVar type annotation: `ClassVar[T]] -> T`.""" + return class_var.__args__[0] + + +@pytest.mark.parametrize( + "config_class, algorithm_type, domain, algorithm_name", + [ + ( + PaccMannRLProteinBasedGenerator, + "conditional_generation", + "materials", + PaccMannRL.__name__, + ), + ( + PaccMannRLOmicBasedGenerator, + "conditional_generation", + "materials", + PaccMannRL.__name__, + ), + ], +) +def test_config_class( + config_class: Type[AlgorithmConfiguration], + algorithm_type: str, + domain: str, + algorithm_name: str, +): + assert config_class.algorithm_type == algorithm_type + assert config_class.domain == domain + assert config_class.algorithm_name == algorithm_name + + for keyword, type_annotation in 
config_class.__annotations__.items(): + if keyword in ("algorithm_type", "domain", "algorithm_name"): + assert type_annotation.__origin__ is ClassVar # type: ignore + assert str == get_classvar_type(type_annotation) + + +@pytest.mark.parametrize( + "config_class", + [ + (PaccMannRLProteinBasedGenerator), + (PaccMannRLOmicBasedGenerator), + ], +) +def test_config_instance(config_class: Type[AlgorithmConfiguration]): + config = config_class() # type:ignore + assert config.algorithm_application == config_class.__name__ + + +@pytest.mark.parametrize( + "config_class", + [ + (PaccMannRLProteinBasedGenerator), + (PaccMannRLOmicBasedGenerator), + ], +) +def test_available_versions(config_class: Type[AlgorithmConfiguration]): + versions = config_class.list_versions() + assert "v0" in versions + + +@pytest.mark.parametrize( + "config, example_target, algorithm", + [ + ( + PaccMannRLProteinBasedGenerator, + "MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTT", + PaccMannRL, + ), + ( + PaccMannRLOmicBasedGenerator, + np.random.rand(2128), + PaccMannRL, + ), + ( + PaccMannRLOmicBasedGenerator, + f"[{','.join(map(str, np.random.rand(2128)))}]", + PaccMannRL, + ), + ], +) +def test_generation_via_import(config, example_target, algorithm): + paccmann_rl = algorithm(configuration=config(), target=example_target) + items = list(paccmann_rl.sample(5)) + assert len(items) == 5 + + +@pytest.mark.parametrize( + "algorithm_application, target", + [ + ( + PaccMannRLProteinBasedGenerator.__name__, + "MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTT", + ), + ], +) +def test_generation_via_registry(target, algorithm_application): + paccmann_rl = ApplicationsRegistry.get_application_instance( + target=target, + algorithm_type="conditional_generation", + domain="materials", + algorithm_name=PaccMannRL.__name__, + algorithm_application=algorithm_application, + generated_length=5, + ) + items = list(paccmann_rl.sample(5)) + assert len(items) == 5 + + +@pytest.mark.parametrize( + "config_class", + [ + (PaccMannRLProteinBasedGenerator), + (PaccMannRLOmicBasedGenerator), + ], +) +def test_configuration_pickable(config_class: Type[AlgorithmConfiguration]): + # implementation + obj = config_class(algorithm_version="test") + + # --- + import inspect + + inspect.getmodule(config_class) + # --- + pickled_obj = pickle.dumps(obj) + restored_obj = pickle.loads(pickled_obj) + assert restored_obj.algorithm_version == "test" + assert restored_obj == obj + + # registered + Config = ApplicationsRegistry.get_application( + algorithm_type="conditional_generation", + domain="materials", + algorithm_name=PaccMannRL.__name__, + algorithm_application=config_class.__name__, + ).configuration_class + + obj = Config(algorithm_version="test") + pickled_obj = pickle.dumps(obj) + restored_obj = pickle.loads(pickled_obj) + + assert restored_obj.algorithm_version == "test" + assert restored_obj == obj diff --git a/src/gt4sd/algorithms/conditional_generation/tests/test_regression_transformer.py b/src/gt4sd/algorithms/conditional_generation/tests/test_regression_transformer.py new file mode 100644 index 000000000..4b6fc3949 --- /dev/null +++ b/src/gt4sd/algorithms/conditional_generation/tests/test_regression_transformer.py @@ -0,0 +1,190 @@ +"""RegressionTransformer tests.""" + +import pickle +from typing import ClassVar, Type + +import pytest + +from gt4sd.algorithms.conditional_generation.regression_transformer import ( + RegressionTransformer, + RegressionTransformerMolecules, + RegressionTransformerProteins, +) +from gt4sd.algorithms.core import 
AlgorithmConfiguration +from gt4sd.algorithms.registry import ApplicationsRegistry + + +def get_classvar_type(class_var): + """Extract type from ClassVar type annotation: `ClassVar[T]] -> T`.""" + return class_var.__args__[0] + + +@pytest.mark.parametrize( + "config_class, algorithm_type, domain, algorithm_name", + [ + ( + RegressionTransformerMolecules, + "conditional_generation", + "materials", + RegressionTransformer.__name__, + ), + ( + RegressionTransformerProteins, + "conditional_generation", + "materials", + RegressionTransformer.__name__, + ), + ], +) +def test_config_class( + config_class: Type[AlgorithmConfiguration], + algorithm_type: str, + domain: str, + algorithm_name: str, +): + assert config_class.algorithm_type == algorithm_type + assert config_class.domain == domain + assert config_class.algorithm_name == algorithm_name + + for keyword, type_annotation in config_class.__annotations__.items(): + if keyword in ("algorithm_type", "domain", "algorithm_name"): + assert type_annotation.__origin__ is ClassVar # type: ignore + assert str == get_classvar_type(type_annotation) + + +@pytest.mark.parametrize( + "config_class", + [ + (RegressionTransformerMolecules), + (RegressionTransformerProteins), + ], +) +def test_config_instance(config_class: Type[AlgorithmConfiguration]): + config = config_class() # type:ignore + assert config.algorithm_application == config_class.__name__ + + +@pytest.mark.parametrize( + "config_class", + [ + (RegressionTransformerMolecules), + (RegressionTransformerProteins), + ], +) +def test_available_versions(config_class: Type[AlgorithmConfiguration]): + versions = config_class.list_versions() + assert "v0" in versions + + +@pytest.mark.parametrize( + "config, example_target, algorithm, params", + [ + ( + RegressionTransformerMolecules, + "[MASK][MASK][MASK][MASK][MASK]|[Cl][C][Branch1_2][Branch1_2][=C][Branch1_1][C][Cl][Cl][Cl]", + RegressionTransformer, + {"search": "greedy", "num_samples": 1}, + ), + ( + RegressionTransformerMolecules, + "-3.499|[C][C][MASK][MASK][MASK][C][Br]", + RegressionTransformer, + {"search": "sample", "temperature": 2.0, "num_samples": 5}, + ), + ( + RegressionTransformerProteins, + "[MASK][MASK][MASK][MASK][MASK]|GSQEVNSNASPEEAEIARKAGATTWTEKGNKWEIRI", + RegressionTransformer, + {"search": "greedy", "num_samples": 1}, + ), + ( + RegressionTransformerProteins, + "1.1234|TTIKNG[MASK][MASK][MASK]YTVPLSPEQAAK[MASK][MASK][MASK]KKRWPDYEVQIHGNTVKVT", + RegressionTransformer, + {"search": "sample", "temperature": 2.0, "num_samples": 5}, + ), + ], +) +def test_generation_via_import(config, example_target, algorithm, params): + num_samples = params.pop("num_samples", 1) + regression_transformer = algorithm( + configuration=config(**params), target=example_target + ) + items = list(regression_transformer.sample(num_samples)) + assert len(items) == num_samples + + +@pytest.mark.parametrize( + "algorithm_application, target, params", + [ + ( + RegressionTransformerMolecules.__name__, + "[MASK][MASK][MASK][MASK][MASK]|[Cl][C][Branch1_2][Branch1_2][=C][Branch1_1][C][Cl][Cl][Cl]", + {"search": "greedy", "num_samples": 1}, + ), + ( + RegressionTransformerMolecules.__name__, + "-3.499|[C][C][MASK][MASK][MASK][C][Br]", + {"search": "sample", "temperature": 2.0, "num_samples": 5}, + ), + ( + RegressionTransformerProteins.__name__, + "[MASK][MASK][MASK][MASK][MASK]|GSQEVNSNASPEEAEIARKAGATTWTEKGNKWEIRI", + {"search": "greedy", "num_samples": 1}, + ), + ( + RegressionTransformerProteins.__name__, + 
"1.1234|TTIKNG[MASK][MASK][MASK]YTVPLSPEQAAK[MASK][MASK][MASK]KKRWPDYEVQIHGNTVKVT", + {"search": "sample", "temperature": 2.0, "num_samples": 5}, + ), + ], +) +def test_generation_via_registry(target, algorithm_application, params): + num_samples = params.pop("num_samples", 1) + regression_transformer = ApplicationsRegistry.get_application_instance( + target=target, + algorithm_type="conditional_generation", + domain="materials", + algorithm_name=RegressionTransformer.__name__, + algorithm_application=algorithm_application, + **params, + ) + items = list(regression_transformer.sample(num_samples)) + assert len(items) == num_samples + + +@pytest.mark.parametrize( + "config_class", + [ + (RegressionTransformerMolecules), + (RegressionTransformerProteins), + ], +) +def test_configuration_pickable(config_class: Type[AlgorithmConfiguration]): + # implementation + obj = config_class(algorithm_version="test") + + # --- + import inspect + + inspect.getmodule(config_class) + # --- + pickled_obj = pickle.dumps(obj) + restored_obj = pickle.loads(pickled_obj) + assert restored_obj.algorithm_version == "test" + assert restored_obj == obj + + # registered + Config = ApplicationsRegistry.get_application( + algorithm_type="conditional_generation", + domain="materials", + algorithm_name=RegressionTransformer.__name__, + algorithm_application=config_class.__name__, + ).configuration_class + + obj = Config(algorithm_version="test") + pickled_obj = pickle.dumps(obj) + restored_obj = pickle.loads(pickled_obj) + + assert restored_obj.algorithm_version == "test" + assert restored_obj == obj diff --git a/src/gt4sd/algorithms/conditional_generation/tests/test_reinvent.py b/src/gt4sd/algorithms/conditional_generation/tests/test_reinvent.py new file mode 100644 index 000000000..0a2b08f47 --- /dev/null +++ b/src/gt4sd/algorithms/conditional_generation/tests/test_reinvent.py @@ -0,0 +1,104 @@ +"""Reinvent tests.""" + +from typing import ClassVar, Type + +import pytest + +from gt4sd.algorithms.conditional_generation.reinvent import Reinvent, ReinventGenerator +from gt4sd.algorithms.core import AlgorithmConfiguration +from gt4sd.algorithms.registry import ApplicationsRegistry + + +def get_classvar_type(class_var): + """Extract type from ClassVar type annotation: `ClassVar[T]] -> T`.""" + return class_var.__args__[0] + + +@pytest.mark.parametrize( + "config_class, algorithm_type, domain, algorithm_name", + [ + ( + ReinventGenerator, + "conditional_generation", + "materials", + Reinvent.__name__, + ) + ], +) +def test_config_class( + config_class: Type[AlgorithmConfiguration], + algorithm_type: str, + domain: str, + algorithm_name: str, +): + assert config_class.algorithm_type == algorithm_type + assert config_class.domain == domain + assert config_class.algorithm_name == algorithm_name + + for keyword, type_annotation in config_class.__annotations__.items(): + if keyword in ("algorithm_type", "domain", "algorithm_name"): + assert type_annotation.__origin__ is ClassVar # type: ignore + assert str == get_classvar_type(type_annotation) + + +@pytest.mark.parametrize( + "config_class", + [ + (ReinventGenerator), + ], +) +def test_config_instance(config_class: Type[AlgorithmConfiguration]): + config = config_class() # type:ignore + assert config.algorithm_application == config_class.__name__ + + +@pytest.mark.parametrize( + "config_class", + [ + (ReinventGenerator), + ], +) +def test_available_versions(config_class: Type[AlgorithmConfiguration]): + versions = config_class.list_versions() + assert "v0" in versions + + 
+@pytest.mark.parametrize( + "config, algorithm", + [ + (ReinventGenerator, Reinvent), + ], +) +def test_generation_via_import(config, algorithm): + config = config(batch_size=10) + algorithm = algorithm(configuration=config, target="N1CCN(CC1)CCCCN") + items = list(algorithm.sample(1)) + assert len(items) == 1 + + +@pytest.mark.parametrize( + "algorithm_application, algorithm_type, domain, algorithm_name", + [ + ( + ReinventGenerator.__name__, + "conditional_generation", + "materials", + Reinvent.__name__, + ), + ], +) +def test_generation_via_registry( + algorithm_type, domain, algorithm_name, algorithm_application +): + algorithm = ApplicationsRegistry.get_application_instance( + target="N1CCN(CC1)CCCCN", + algorithm_type=algorithm_type, + domain=domain, + algorithm_name=algorithm_name, + algorithm_application=algorithm_application, + randomize=True, + sample_uniquely=True, + batch_size=5, + ) + items = list(algorithm.sample(5)) + assert len(items) == 5 diff --git a/src/gt4sd/algorithms/controlled_sampling/__init__.py b/src/gt4sd/algorithms/controlled_sampling/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/gt4sd/algorithms/controlled_sampling/advanced_manufacturing/__init__.py b/src/gt4sd/algorithms/controlled_sampling/advanced_manufacturing/__init__.py new file mode 100644 index 000000000..6a42c2446 --- /dev/null +++ b/src/gt4sd/algorithms/controlled_sampling/advanced_manufacturing/__init__.py @@ -0,0 +1,5 @@ +"""AdvancedManufacturing initialization.""" + +from .core import AdvancedManufacturing, CatalystGenerator + +__all__ = ["AdvancedManufacturing", "CatalystGenerator"] diff --git a/src/gt4sd/algorithms/controlled_sampling/advanced_manufacturing/core.py b/src/gt4sd/algorithms/controlled_sampling/advanced_manufacturing/core.py new file mode 100644 index 000000000..56a6abe67 --- /dev/null +++ b/src/gt4sd/algorithms/controlled_sampling/advanced_manufacturing/core.py @@ -0,0 +1,135 @@ +"""Advanced manufacturing algorithms.""" + +import logging +from dataclasses import field +from typing import Any, Callable, ClassVar, Dict, Iterable, Optional, TypeVar + +from ....domains.materials import SmallMolecule +from ...core import AlgorithmConfiguration, GeneratorAlgorithm +from ...registry import ApplicationsRegistry +from .implementation.core import Generator +from .implementation.nccr import CatalystGenerator as NCCRCatalystGenerator + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + +T = TypeVar("T", bound=Any) +S = TypeVar("S", bound=Any) +Targeted = Callable[[T], Iterable[Any]] + + +class AdvancedManufacturing(GeneratorAlgorithm[S, T]): + """Advance manufacturing generator algorithm.""" + + def __init__( + self, + configuration: AlgorithmConfiguration[S, T], + target: Optional[T], + ): + """Instantiate AdvancedManufacturing ready to generate items. + + Args: + configuration: domain and application + specification defining parameters, types and validations. + target: a target for which to generate items. 
+
+        Example:
+            An example for generating small molecules (SMILES) with a target binding energy::
+
+                config = CatalystGenerator()
+                algorithm = AdvancedManufacturing(configuration=config, target=10.0)
+                items = list(algorithm.sample(10))
+                print(items)
+        """
+
+        configuration = self.validate_configuration(configuration)
+        # TODO there might also be a validation/check on the target input
+
+        super().__init__(
+            configuration=configuration,  # type:ignore
+            target=target,  # type:ignore
+        )
+
+    def get_generator(
+        self,
+        configuration: AlgorithmConfiguration[S, T],
+        target: Optional[T],
+    ) -> Targeted[T]:
+        """Get the function to sample batches via AdvancedManufacturing's generator.
+
+        Args:
+            configuration: helps to set up specific application of AdvancedManufacturing.
+            target: context or condition for the generation.
+
+        Returns:
+            callable with target generating a batch of items.
+        """
+        logger.info("ensure artifacts for the application are present.")
+        self.local_artifacts = configuration.ensure_artifacts()
+        implementation: Generator = configuration.get_conditional_generator(  # type: ignore
+            self.local_artifacts
+        )
+        return implementation.generate_samples
+
+
+@ApplicationsRegistry.register_algorithm_application(AdvancedManufacturing)
+class CatalystGenerator(AlgorithmConfiguration[SmallMolecule, float]):
+    """Configuration to generate catalysts with a desired binding energy."""
+
+    algorithm_type: ClassVar[str] = "controlled_sampling"
+    domain: ClassVar[str] = "materials"
+    algorithm_version: str = "v0"
+
+    number_of_points: int = field(
+        default=32,
+        metadata=dict(
+            description="Number of points to sample with the Gaussian Process."
+        ),
+    )
+    number_of_steps: int = field(
+        default=50,
+        metadata=dict(
+            description="Number of optimization steps in the Gaussian Process optimization."
+        ),
+    )
+    generated_length: int = field(
+        default=100,
+        metadata=dict(
+            description="Maximum length in tokens of the generated molecules (relates to the SMILES length)."
+        ),
+    )
+    primer_smiles: str = field(
+        default="",
+        metadata=dict(
+            description="Primer molecule to initiate the sampling in SMILES format. Defaults to no primer."
+        ),
+    )
+
+    def get_target_description(self) -> Dict[str, str]:
+        """Get description of the target for generation.
+
+        Returns:
+            target description.
+        """
+        return {
+            "title": "Target energy",
+            "description": "Binding energy target for the catalysts generated.",
+            "type": "number",
+        }
+
+    def get_conditional_generator(self, resources_path: str) -> Generator:
+        """Instantiate the actual generator implementation.
+
+        Args:
+            resources_path: local path to model files.
+
+        Returns:
+            instance with :meth:`generate_samples` method for targeted generation.
+        """
+        return NCCRCatalystGenerator(
+            resources_path=resources_path,
+            generated_length=self.generated_length,
+            number_of_points=self.number_of_points,
+            number_of_steps=self.number_of_steps,
+            primer_smiles=self.primer_smiles,
+        )
diff --git a/src/gt4sd/algorithms/controlled_sampling/advanced_manufacturing/implementation/__init__.py b/src/gt4sd/algorithms/controlled_sampling/advanced_manufacturing/implementation/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/gt4sd/algorithms/controlled_sampling/advanced_manufacturing/implementation/core.py b/src/gt4sd/algorithms/controlled_sampling/advanced_manufacturing/implementation/core.py
new file mode 100644
index 000000000..39311aba8
--- /dev/null
+++ b/src/gt4sd/algorithms/controlled_sampling/advanced_manufacturing/implementation/core.py
@@ -0,0 +1,495 @@
+"""Controlled sampling of concatenated encodings via Gaussian Process."""
+
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
+
+import numpy as np
+import pandas as pd
+import torch
+from skopt import gp_minimize
+from skopt.space import Real
+
+from .....frameworks.torch import device_claim
+
+Point = Union[np.ndarray, torch.Tensor, pd.Series]
+MINIMUM_REAL = np.finfo(dtype=np.float32).min
+MAXIMUM_REAL = np.finfo(dtype=np.float32).max
+
+
+def point_to_tensor(point: Point) -> torch.Tensor:
+    """Convert point to tensor.
+
+    Args:
+        point: a point.
+
+    Returns:
+        tensor representing a point.
+    """
+    return (
+        point.clone().detach().float()
+        if isinstance(point, torch.Tensor)
+        else torch.tensor(point).float()
+    )
+
+
+@dataclass
+class Representation:
+    """
+    A generic representation for a composition problem.
+
+    Attributes:
+        model: a torch module for decoding.
+        z_dimension: dimension of the latent space.
+        fixed_representation: fixed representation in the latent space.
+        z_index: slice for indexing a point to represent the latent space.
+    """
+
+    model: torch.nn.Module
+    z_dimension: int
+    fixed_representation: Optional[torch.Tensor] = None
+    z_index: Optional[slice] = None
+
+    def decode(self, z: Point) -> Any:
+        """Decode the representation from the latent space.
+
+        Args:
+            z: a point in the latent space.
+
+        Returns:
+            the decoded representation.
+        """
+        z = torch.unsqueeze(point_to_tensor(z), dim=0)
+        reconstructed = self.model.decode(z)  # type: ignore
+        return reconstructed
+
+    def deserialize(
+        self, filepath: str, device: Optional[Union[torch.device, str]] = None
+    ) -> None:
+        """
+        Deserialize a representation from file.
+
+        Args:
+            filepath: path to the serialized representation.
+            device: device where the inference
+                is running either as a dedicated class or a string. If not provided, it is inferred.
+        """
+        device = device_claim(device)
+        weights = torch.load(filepath, map_location=device)
+        self.model = self.model.to(device)  # type:ignore
+        self.model.load_state_dict(weights)
+
+
+class PropertyPredictor:
+    """Property prediction class.
+
+    Attributes:
+        input_representations: order of the input representations.
+    """
+
+    input_representations: Optional[List[str]] = None
+
+    def __call__(self, z: Point) -> float:
+        """Call the property predictor on the point.
+
+        Args:
+            z: the point.
+
+        Returns:
+            the predicted property.
+        """
+        raise NotImplementedError("Property prediction not implemented.")
+
+
+class Scaler:
+    """Scaler class."""
+
+    def __call__(self, example: Any) -> Any:
+        """Scale the example appropriately.
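+
+        Subclasses are expected to implement the actual scaling; the base
+        class only defines the interface.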
+
+        Args:
+            example: an example prior to scaling.
+
+        Returns:
+            the scaled example.
+        """
+        raise NotImplementedError("Scaler not implemented.")
+
+
+RepresentationsDict = Dict[str, Representation]
+
+
+class Objective:
+    """
+    Objective function for representations.
+    """
+
+    def __init__(
+        self,
+        targets: Dict[str, float],
+        property_predictors: Dict[str, PropertyPredictor],
+        representations: RepresentationsDict,
+        representation_order: Optional[List[str]] = None,
+        scalers: Optional[Dict[str, Scaler]] = None,
+        weights: Optional[Dict[str, float]] = None,
+        custom_score_function: Optional[
+            Callable[[Point, RepresentationsDict, Optional[Dict[str, Scaler]]], float]
+        ] = None,
+        custom_score_weight: float = 1.0,
+        minimize: bool = True,
+    ):
+        """Constructs an objective function.
+
+        Args:
+            targets: a dictionary of target values.
+            property_predictors: a dictionary of target property predictors.
+            representations: a dictionary of decodeable representations.
+            representation_order: order of the representations. Defaults to None, a.k.a., lexicographic order.
+            scalers: scalers for representation features. Defaults to None, a.k.a., no scaling.
+            weights: weights for each target. Defaults to None, a.k.a., targets evenly weighted.
+            custom_score_function: a custom score function to apply on decoded representations. Defaults to None, a.k.a., no custom score.
+            custom_score_weight: weight for the custom score. Defaults to 1.0.
+            minimize: whether the objective needs to be minimized. Defaults to True.
+        """
+        self.targets = targets
+        self.property_predictors = property_predictors
+        self.representations = representations
+        if representation_order is None:
+            self.representation_order = sorted(list(representations.keys()))
+        else:
+            self.representation_order = representation_order
+        self.scalers = scalers
+        self.weights = weights
+        self.custom_score_function = custom_score_function
+        self.custom_score_weight = custom_score_weight
+        self.minimize = minimize
+
+        if self.weights is None:
+            weights_dictionary = dict()
+            for target in self.targets.keys():
+                weights_dictionary[target] = 1.0
+            self.weights = weights_dictionary
+
+    def construct_property_representation(
+        self, z: torch.Tensor, property_name: str
+    ) -> torch.Tensor:
+        """Construct a representation for a specific property.
+
+        The encoded point and fixed encodings (or slices thereof) if available
+        are concatenated in the right order.
+
+        Todo:
+            Check explanation and improve it.
+
+        Args:
+            z: the point.
+            property_name: name of the property for which to construct the representation.
+
+        Returns:
+            representation for which a specific property can be predicted.
+        """
+        # TODO make generic for self.representations: RepresentationsDict
+        # defer this to the configuration, or some other place that defines how representations belong together
+        property_predictor = self.property_predictors[property_name]
+        if property_predictor.input_representations:
+            representation_names = property_predictor.input_representations
+        else:
+            representation_names = self.representation_order
+        z_list = []
+        for representation_name in representation_names:
+            representation = self.representations[representation_name]
+            if representation.fixed_representation is not None:
+                z_list.append(representation.fixed_representation)
+            else:
+                z_list.append(z[representation.z_index])
+        z_latent = torch.cat(z_list)
+        return z_latent.reshape(-1, z_latent.shape[0])
+
+    def evaluate(self, z: Point) -> float:
+        """Evaluate objective function for a point in latent space.
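+
+        The score is the sum over properties of the absolute deviation between
+        predicted and target values, plus the weighted custom score when a
+        custom score function is configured.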
+
+        Args:
+            z: the point.
+
+        Returns:
+            the score of the point.
+        """
+        z = point_to_tensor(z)
+        # predict all properties
+        predicted_properties = dict()
+        with torch.no_grad():
+            for property_name in self.targets.keys():
+                property_predictor = self.property_predictors[property_name]
+                latent_z = self.construct_property_representation(
+                    z=z, property_name=property_name
+                )
+                predicted_properties[property_name] = property_predictor(latent_z)
+
+        # TODO aggregate the following scores over different properties,
+        # or really only keep last one?
+        if self.custom_score_function:
+            custom_score = self.custom_score_function(
+                latent_z,
+                self.representations,
+                self.scalers,
+            )
+        else:
+            custom_score = 0
+
+        scores = []
+        for property_name, predicted_property in predicted_properties.items():
+            scores.append(abs(self.targets[property_name] - predicted_property))
+        score = sum(scores)
+
+        # weight `custom_score` to compensate for it not being normalized
+        total_score = score + self.custom_score_weight * custom_score
+
+        if not self.minimize:
+            total_score = -1 * total_score
+        return total_score
+
+
+class GaussianProcessRepresentationsSampler:
+    def __init__(
+        self,
+        targets: Dict[str, float],
+        property_predictors: Dict[str, PropertyPredictor],
+        representations: RepresentationsDict,
+        representation_order: Optional[List[str]] = None,
+        bounds: Optional[
+            Dict[str, Union[List[Tuple[float, float]], Tuple[float, float]]]
+        ] = None,
+        # TODO Any should be type of scaler; default to lambda returning 0?
+        scalers: Optional[Dict[str, Scaler]] = None,
+        weights: Optional[Dict[str, float]] = None,
+        custom_score_function: Optional[
+            Callable[[Point, RepresentationsDict, Optional[Dict[str, Scaler]]], float]
+        ] = None,
+        custom_score_weight: float = 1.0,
+        minimize: bool = True,
+        random_state: int = 42,
+        random_starts: int = 10,
+    ):
+        """Constructs a GaussianProcessRepresentationsSampler.
+
+        Args:
+            targets: a dictionary of target values.
+            property_predictors: a dictionary of target property predictors.
+            representations: a dictionary of decodeable representations.
+            representation_order: order of the representations. Defaults to None, a.k.a., lexicographic order.
+            bounds: bounds for the optimization. Defaults to None, a.k.a., unbounded.
+            scalers: scalers for representation features. Defaults to None, a.k.a., no scaling.
+            weights: weights for each target. Defaults to None, a.k.a., targets evenly weighted.
+            custom_score_function: a custom score function to apply on decoded representations. Defaults to None, a.k.a., no custom score.
+            custom_score_weight: weight for the custom score. Defaults to 1.0.
+            minimize: whether the objective needs to be minimized. Defaults to True.
+            random_state: random state. Defaults to 42.
+            random_starts: number of random restarts. Defaults to 10.
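+
+        Example:
+            An illustrative sketch, assuming a decodeable representation
+            (``vae``) and a matching property predictor (``predictor``) are
+            available::
+
+                sampler = GaussianProcessRepresentationsSampler(
+                    targets={"energy": 10.0},
+                    property_predictors={"energy": predictor},
+                    representations={"smiles": vae},
+                    bounds={"smiles": (-100.0, 100.0)},
+                )
+                results = sampler.optimize(number_of_points=5)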
+        """
+        self.targets = targets
+        self.property_predictors = property_predictors
+        self.representations = representations
+        self.representation_order = representation_order
+        if self.representation_order is None:
+            self.representation_order = sorted(list(representations.keys()))
+        self.scalers = scalers
+        self.weights = weights
+        self.custom_score_function = custom_score_function
+        self.custom_score_weight = custom_score_weight
+        self.minimize = minimize
+        self.random_state = random_state
+        self.random_starts = random_starts
+        self.set_bounds(bounds)
+        self.dimensions = self.define_dimensions(self.representation_order)
+
+    def _get_bounds(
+        self, minimum_value: float, maximum_value: float, z_dimension: int
+    ) -> List[Tuple[float, float]]:
+        """
+        Define a list of bounds for a hypercube.
+
+        Args:
+            minimum_value: minimum value.
+            maximum_value: maximum value.
+            z_dimension: dimension of the hypercube.
+
+        Returns:
+            the list of bounds.
+        """
+        return [(minimum_value, maximum_value) for _ in range(z_dimension)]
+
+    def set_bounds(
+        self,
+        bounds: Optional[
+            Dict[str, Union[List[Tuple[float, float]], Tuple[float, float]]]
+        ] = None,
+    ) -> None:
+        """Set the bounds for the optimization.
+
+        Args:
+            bounds: bounds for the optimization. Defaults to None, a.k.a., unbounded.
+        """
+        self.bounds = bounds if bounds else dict()
+        for representation_name, bounds in self.bounds.items():  # type:ignore
+            z_dimension = self.representations[representation_name].z_dimension
+            if isinstance(bounds, tuple):
+                self.bounds[representation_name] = self._get_bounds(
+                    bounds[0], bounds[1], z_dimension
+                )
+            else:
+                self.bounds[representation_name] = bounds  # type:ignore
+        for representation_name in self.representations.keys() - self.bounds.keys():
+            z_dimension = self.representations[representation_name].z_dimension
+            self.bounds[representation_name] = self._get_bounds(
+                MINIMUM_REAL, MAXIMUM_REAL, z_dimension
+            )
+
+    def define_dimensions(self, representation_order: List[str]) -> List[Real]:
+        """Define the dimensions of the optimization space.
+
+        Args:
+            representation_order: order of the representations.
+
+        Returns:
+            a list of dimensions.
+        """
+        dimensions = []
+        latent_index = 0
+        for representation_name in representation_order:
+            representation = self.representations[representation_name]
+            representation_bounds = self.bounds[representation_name]
+            representation.z_index = slice(
+                latent_index, latent_index + representation.z_dimension
+            )
+            latent_index += representation.z_dimension
+            dimensions.extend(
+                [
+                    Real(lower_bound, upper_bound)  # type:ignore
+                    for lower_bound, upper_bound in representation_bounds
+                ]
+            )
+        return dimensions
+
+    def optimize(
+        self,
+        targets: Optional[Dict[str, float]] = None,
+        relevant_representations: Optional[List[str]] = None,
+        representation_order: Optional[List[str]] = None,
+        z0: Optional[Point] = None,
+        number_of_points: int = 1,
+        number_of_steps: int = 50,
+        random_state: int = 42,
+        verbose: bool = False,
+        weights: Optional[Dict[str, float]] = None,
+        acquisition_method: str = "PI",
+    ) -> List[Dict[str, Any]]:
+        """Run the optimization.
+
+        Args:
+            targets: a dictionary of target values. Defaults to None, a.k.a., use the ones defined at construction time.
+            relevant_representations: list of relevant representations to be optimized. Defaults to None, a.k.a., inferred from non-fixed representations.
+            representation_order: order of the representations. Defaults to None, a.k.a., use the one defined at construction time.
+            z0: the starting point for the optimization. Defaults to None, a.k.a., perform random starts.
+            number_of_points: number of optimal points to return. Defaults to 1.
+            number_of_steps: number of optimization steps. Defaults to 50.
+            random_state: random state. Defaults to 42.
+            verbose: control verbosity. Defaults to False.
+            weights: weights for each target. Defaults to None, a.k.a., use the ones defined at construction time.
+            acquisition_method: acquisition method to use in the Gaussian Process optimization. Defaults to "PI". More details at: https://scikit-optimize.github.io/stable/modules/generated/skopt.gp_minimize.html.
+
+        Raises:
+            NotImplementedError: invalid acquisition function.
+
+        Returns:
+            list of ordered optimal points with decoded relevant representations.
+        """
+        if representation_order is None:
+            representation_order = self.representation_order
+        dimensions = self.define_dimensions(cast(List[str], representation_order))
+
+        np.random.seed(random_state)
+
+        if acquisition_method not in ["PI", "LCB", "EI", "gp_hedge", "EIps", "PIps"]:
+            raise NotImplementedError(
+                f"Acquisition method {acquisition_method} is not supported."
+            )
+
+        if targets is None:
+            targets = self.targets
+
+        if weights is None:
+            weights = self.weights
+
+        objective = Objective(
+            targets=targets,
+            property_predictors=self.property_predictors,
+            representations=self.representations,
+            representation_order=representation_order,
+            scalers=self.scalers,
+            weights=weights,
+            custom_score_function=self.custom_score_function,
+            custom_score_weight=self.custom_score_weight,
+            minimize=self.minimize,
+        )
+
+        y0 = None
+        random_starts: Optional[int]
+        if z0 is None:
+            random_starts = self.random_starts
+        else:
+            random_starts = None
+
+        gaussian_process_results = gp_minimize(
+            objective.evaluate,
+            dimensions,
+            n_calls=number_of_steps,
+            n_random_starts=random_starts,
+            x0=z0,
+            y0=y0,
+            acq_func=acquisition_method,
+            random_state=np.random.RandomState(random_state),
+            verbose=verbose,
+        )
+
+        objective_values, points = zip(
+            *sorted(
+                zip(
+                    gaussian_process_results.func_vals, gaussian_process_results.x_iters
+                )
+            )
+        )
+
+        results_list = []
+
+        if relevant_representations is None:
+            relevant_representations = [
+                representation_name
+                for representation_name, representation in self.representations.items()
+                if representation.fixed_representation is None
+            ]
+
+        for point, objective_value in zip(
+            points[:number_of_points], objective_values[:number_of_points]
+        ):
+            optimization_result = {"objective": objective_value, "z": point}
+            for representation_name in relevant_representations:
+                representation = self.representations[representation_name]
+                z = point[representation.z_index]
+                optimization_result[representation_name] = representation.decode(z)
+
+            results_list.append(optimization_result)
+
+        return results_list
+
+
+class Generator:
+    def generate_samples(self, target: Any) -> List[Any]:
+        """Generate samples.
+
+        Args:
+            target: target for generation.
+
+        Returns:
+            samples generated.
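+
+        Raises:
+            NotImplementedError: always, concrete generators override this method.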
+        """
+        raise NotImplementedError("Generate samples not implemented.")
diff --git a/src/gt4sd/algorithms/controlled_sampling/advanced_manufacturing/implementation/nccr/__init__.py b/src/gt4sd/algorithms/controlled_sampling/advanced_manufacturing/implementation/nccr/__init__.py
new file mode 100644
index 000000000..d0e3e2739
--- /dev/null
+++ b/src/gt4sd/algorithms/controlled_sampling/advanced_manufacturing/implementation/nccr/__init__.py
@@ -0,0 +1,7 @@
+"""NCCR module initialization."""
+
+from .core import (  # noqa: F401
+    CatalystBindingEnergyPredictor,
+    CatalystGenerator,
+    CatalystVAE,
+)
diff --git a/src/gt4sd/algorithms/controlled_sampling/advanced_manufacturing/implementation/nccr/core.py b/src/gt4sd/algorithms/controlled_sampling/advanced_manufacturing/implementation/nccr/core.py
new file mode 100644
index 000000000..528a223a7
--- /dev/null
+++ b/src/gt4sd/algorithms/controlled_sampling/advanced_manufacturing/implementation/nccr/core.py
@@ -0,0 +1,222 @@
+"""Catalyst design for NCCR project."""
+
+import logging
+import os
+import re
+from typing import List, Union, cast
+
+import torch
+
+from ......frameworks.granular.ml.models import (
+    GranularEncoderDecoderModel,
+    MlpPredictor,
+)
+from ......frameworks.granular.ml.module import GranularModule
+from ......frameworks.granular.tokenizer.tokenizer import SmilesTokenizer
+from ..core import (
+    GaussianProcessRepresentationsSampler,
+    Generator,
+    Point,
+    PropertyPredictor,
+    Representation,
+    point_to_tensor,
+)
+
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler())
+
+
+class CatalystVAE(Representation):
+    """Catalyst VAE for Suzuki reactions."""
+
+    model: GranularEncoderDecoderModel
+
+    def __init__(
+        self,
+        resources_path: str,
+        padding_length: int = 127,
+        maximum_length: int = 100,
+        primer_smiles: str = "",
+    ) -> None:
+        """Constructs a CatalystVAE.
+
+        Args:
+            resources_path: directory where to find models and configurations.
+            padding_length: size of the padded sequence. Defaults to 127.
+            maximum_length: maximum length of the generated sequence.
+            primer_smiles: primer SMILES representation. Defaults to "", a.k.a., no primer.
+        """
+        self.vocabulary_filepath = os.path.join(resources_path, "vocab_combined.csv")
+        self.checkpoint_filepath = os.path.join(
+            resources_path, "epoch=199-step=5799.ckpt"
+        )
+        self.tokenizer = SmilesTokenizer(self.vocabulary_filepath)
+        self.model = cast(
+            GranularEncoderDecoderModel,
+            GranularModule.load_from_checkpoint(self.checkpoint_filepath).autoencoders[
+                0
+            ],
+        )
+        self.model.eval()
+        self.padding_length = padding_length
+        self.z_dimension = self.model.latent_size
+        self.maximum_length = maximum_length
+        self.primer_smiles = primer_smiles
+        if len(self.primer_smiles) > 0:
+            self.primer_point = self.smiles_to_latent(self.primer_smiles)
+        else:
+            self.primer_point = torch.zeros(1, self.z_dimension)
+        self.clean_regex = re.compile(
+            r"{}|{}".format(self.tokenizer.sos_token, self.tokenizer.unk_token)
+        )
+        self.end_regex = re.compile(
+            r"{}|{}".format(self.tokenizer.eos_token, self.tokenizer.pad_token)
+        )
+
+    def smiles_to_latent(self, smiles: str) -> Point:
+        """Encode a SMILES into a latent point.
+
+        Args:
+            smiles: a SMILES representation of a molecule.
+
+        Returns:
+            the encoded latent space point.
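+
+        Example:
+            An illustrative sketch, assuming the required artifacts are
+            available locally (the path is hypothetical)::
+
+                vae = CatalystVAE(resources_path="/path/to/nccr/artifacts")
+                z = vae.smiles_to_latent("CCO")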
+        """
+        return self.model.encode(  # type:ignore
+            torch.tensor(
+                [
+                    self.tokenizer.add_padding_tokens(
+                        self.tokenizer.convert_tokens_to_ids(
+                            [self.tokenizer.sos_token]
+                            + self.tokenizer.tokenize(smiles)
+                            + [self.tokenizer.eos_token]
+                        ),
+                        length=self.padding_length,
+                    )
+                ]
+            )
+        )
+
+    def decode(self, z: Point) -> str:
+        """Decode a catalyst from the latent space.
+
+        Args:
+            z: a latent space point.
+
+        Returns:
+            a catalyst in SMILES format.
+        """
+        z = torch.unsqueeze(point_to_tensor(z), dim=0)
+        reconstructed = self.model.decode(z, max_len=self.maximum_length)[0][
+            0
+        ]  # type:ignore
+        reconstructed = self.clean_regex.sub("", reconstructed)
+        match_ending = self.end_regex.search(reconstructed)
+        if match_ending:
+            reconstructed = reconstructed[: match_ending.start()]
+        return reconstructed
+
+
+class CatalystBindingEnergyPredictor(PropertyPredictor):
+    """Catalyst binding energy predictor for Suzuki reactions."""
+
+    model: MlpPredictor
+
+    def __init__(self, resources_path: str) -> None:
+        """Constructs a CatalystBindingEnergyPredictor.
+
+        Args:
+            resources_path: directory where to find models and configurations.
+        """
+        self.vocabulary_filepath = os.path.join(resources_path, "vocab_combined.csv")
+        self.checkpoint_filepath = os.path.join(
+            resources_path, "epoch=199-step=5799.ckpt"
+        )
+        self.tokenizer = SmilesTokenizer(self.vocabulary_filepath)
+        self.model = cast(
+            MlpPredictor,
+            GranularModule.load_from_checkpoint(self.checkpoint_filepath).latent_models[
+                0
+            ],
+        )
+        self.model.eval()
+
+    def __call__(self, z: Point) -> float:
+        """Predict binding energy.
+
+        Args:
+            z: a latent space point.
+
+        Returns:
+            the predicted binding energy.
+        """
+        z = point_to_tensor(z)
+        return self.model(z)[0][0].item()
+
+
+class CatalystGenerator(Generator):
+    """Catalyst generator."""
+
+    def __init__(
+        self,
+        resources_path: str,
+        generated_length: int = 100,
+        number_of_points: int = 10,
+        number_of_steps: int = 50,
+        primer_smiles: str = "",
+    ):
+        """Constructs catalyst generator.
+
+        Args:
+            resources_path: directory where to find models and configurations.
+            generated_length: maximum length of the generated molecule. Defaults to 100.
+            number_of_points: number of optimal points to return. Defaults to 10.
+            number_of_steps: number of optimization steps. Defaults to 50.
+            primer_smiles: primer SMILES representation. Defaults to "", a.k.a., no primer.
+        """
+        self.resources_path = resources_path
+        self.generated_length = generated_length
+        self.number_of_points = number_of_points
+        self.number_of_steps = max(self.number_of_points, number_of_steps)
+        self.primer_smiles = primer_smiles
+        self.vae = CatalystVAE(
+            resources_path,
+            maximum_length=self.generated_length,
+            primer_smiles=primer_smiles,
+        )
+        self.predictor = CatalystBindingEnergyPredictor(resources_path)
+        self.minimum_latent_coordinate = -100.0
+        self.maximum_latent_coordinate = 100.0
+
+    def generate_samples(self, target_energy: Union[float, str]) -> List[str]:
+        """Generate samples given a target energy.
+
+        Args:
+            target_energy: target energy value.
+
+        Returns:
+            catalysts sampled for the target value.
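+
+        Example:
+            An illustrative sketch (the artifacts path is hypothetical)::
+
+                generator = CatalystGenerator(resources_path="/path/to/nccr/artifacts")
+                catalysts = generator.generate_samples(10.0)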
+        """
+        if isinstance(target_energy, str):
+            logger.warning(
+                f"target energy ({target_energy}) passed as string, casting to float"
+            )
+            target_energy = float(target_energy)
+        sampler = GaussianProcessRepresentationsSampler(
+            {"energy": target_energy},
+            property_predictors={"energy": self.predictor},
+            representations={"smiles": self.vae},
+            bounds={
+                "smiles": (
+                    self.minimum_latent_coordinate,
+                    self.maximum_latent_coordinate,
+                )
+            },
+        )
+        return [
+            sample["smiles"]
+            for sample in sampler.optimize(
+                number_of_points=self.number_of_points,
+                number_of_steps=self.number_of_steps,
+            )
+        ]
diff --git a/src/gt4sd/algorithms/controlled_sampling/class_controlled_sampling/__init__.py b/src/gt4sd/algorithms/controlled_sampling/class_controlled_sampling/__init__.py
new file mode 100644
index 000000000..d56253534
--- /dev/null
+++ b/src/gt4sd/algorithms/controlled_sampling/class_controlled_sampling/__init__.py
@@ -0,0 +1,14 @@
+"""Controlled Latent attribute Space Sampling initialization."""
+import logging
+
+from ....extras import EXTRAS_ENABLED
+
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler())
+
+if EXTRAS_ENABLED:
+    from .core import PAG, CLaSS, CogMol
+
+    __all__ = ["CLaSS", "CogMol", "PAG"]
+else:
+    logger.warning("install cogmol-inference extras to use CLaSS")
diff --git a/src/gt4sd/algorithms/controlled_sampling/class_controlled_sampling/core.py b/src/gt4sd/algorithms/controlled_sampling/class_controlled_sampling/core.py
new file mode 100644
index 000000000..7366e5220
--- /dev/null
+++ b/src/gt4sd/algorithms/controlled_sampling/class_controlled_sampling/core.py
@@ -0,0 +1,197 @@
+"""CLaSS Algorithm: PAG and CogMol applications."""
+
+import logging
+from dataclasses import field
+from typing import Any, Callable, ClassVar, Dict, Iterable, Optional, TypeVar, Union
+
+from ....extras import EXTRAS_ENABLED
+from ...core import AlgorithmConfiguration, GeneratorAlgorithm  # type: ignore
+from ...registry import ApplicationsRegistry  # type: ignore
+
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler())
+
+if EXTRAS_ENABLED:
+    from cog.sample_pipeline import CogMolFiles, read_artifacts_config
+    from pag.sample_pipeline import PAGFiles
+
+    from .implementation import CogMolGenerator, PAGGenerator
+
+    T = TypeVar("T")
+    S = TypeVar("S")
+    Targeted = Callable[[T], Iterable[Any]]
+    Untargeted = Callable[[], Iterable[Any]]
+
+    class CLaSS(GeneratorAlgorithm[S, T]):
+        """Controlled Latent attribute Space Sampling (CLaSS) Algorithm."""
+
+        def __init__(
+            self,
+            configuration: AlgorithmConfiguration[S, T],
+            target: Optional[T] = None,
+        ):
+            """Instantiate CLaSS ready to generate items.
+
+            Args:
+                configuration: domain and application
+                    specification, defining types and validations.
+                target: optional target; in this instance it will be converted to a string.
+ + Example: + An example for using the CogMol application with this Algorithm:: + + # target protein + MPRO = "SGFRKMAFPSGKVEGCMVQVTCGTTTLNGLWLDDVVYCPRHVICTSEDMLNPNYEDLLIRKSNHNFLVQAGNVQLRVIGHSMQNCVLKLKVDTANPKTPKYKFVRIQPGQTFSVLACYNGSPSGVYQCAMRPNFTIKGSFLNGSCGSVGFNIDYDCVSFCYMHHMELPTGVHAGTDLEGNFYGPFVDRQTAQAAGTDTTITVNVLAWLYAAVINGDRWFLNRFTTTLNDFNLVAMKYNYEPLTQDHVDILGPLSAQTGIAVLDMCASLKELLQNGMNGRTILGSALLEDEFTPFDVVRQCSGVTFQ" + configuration = CogMol() + algorithm = CLaSS(configuration=configuration, target=MPRO) + items = list(algorithm.sample(1)) + print(items) + + We can also use the PAG application similarly:: + + configuration = PAG() + algorithm = CLaSS(configuration=configuration) + items = list(algorithm.sample(1)) + print(items) + """ + + configuration = self.validate_configuration(configuration) + # TODO there might also be a validation/check on the target input + + super().__init__( + configuration=configuration, + target=target, # type:ignore + ) + + def get_generator( + self, + configuration: AlgorithmConfiguration[S, T], + target: Optional[T], + ) -> Union[Untargeted, Targeted[T]]: + """Get the function to sample from generator. + + Args: + configuration: helps to set up the application. + target: target to generate molecules against. + + Returns: + callable generating a list of molecules. + """ + logger.info("ensure artifacts for the application are present.") + self.local_artifacts = configuration.ensure_artifacts() + implementation = configuration.get_class_instance( # type: ignore + resources_path=self.local_artifacts, target=target + ) + return implementation.sample_accepted + + def validate_configuration( + self, configuration: AlgorithmConfiguration + ) -> AlgorithmConfiguration: + # TODO raise InvalidAlgorithmConfiguration + assert isinstance(configuration, AlgorithmConfiguration) + return configuration + + @ApplicationsRegistry.register_algorithm_application(CLaSS) + class CogMol(AlgorithmConfiguration[str, str]): + """Configuration for CogMol: Target-Specific and Selective Drug Design.""" + + algorithm_type: ClassVar[str] = "controlled_sampling" + domain: ClassVar[str] = "materials" + algorithm_version: str = "v0" + + samples_per_round: int = field( + default=200, + metadata=dict( + description="Number of generated samples for acceptance/rejection per round." + ), + ) + max_length: int = field( + default=100, + metadata=dict(description="Maximal number of tokens in generated samples."), + ) + temperature: float = field( + default=1.0, + metadata=dict(description="Temperature of softmax."), + ) + num_proteins_selectivity: int = field( + default=10, + metadata=dict( + description="Number of random samples for measuring off target selectivity for rejection." + ), + ) + + def get_target_description(self) -> Dict[str, str]: + """Get description of the target for generation. + Returns: + target description. 
+ """ + return { + "title": "Protein", + "description": "Primary structure of the target protein as sequence of amino acid characters.", + "type": "string", + } + + def get_class_instance(self, resources_path: str, target: str): + try: + config = read_artifacts_config(resources_path) + bindingdb_date = config["cogmol version information"]["bindingdb_date"] + except KeyError: + bindingdb_date = None + + return CogMolGenerator( + protein_sequence=target, + model_files=CogMolFiles.from_directory_with_config(resources_path), + n_samples_per_round=self.samples_per_round, + device="cpu", + num_proteins_selectivity=self.num_proteins_selectivity, + temp=self.temperature, + max_len=self.max_length, + bindingdb_date=bindingdb_date, + ) + + @ApplicationsRegistry.register_algorithm_application(CLaSS) + class PAG(AlgorithmConfiguration[str, str]): + """Configuration for photoacid generator (PAG) design.""" + + algorithm_type: ClassVar[str] = "controlled_sampling" + domain: ClassVar[str] = "materials" + algorithm_version: str = "v0" + + samples_per_round: int = field( + default=200, + metadata=dict( + description="Number of generated samples for acceptance/rejection per round." + ), + ) + max_length: int = field( + default=100, + metadata=dict(description="Maximal number of tokens in generated samples."), + ) + temperature: float = field( + default=1.0, + metadata=dict(description="Temperature of softmax."), + ) + + def get_target_description(self) -> None: + """Untargeted sampling. Always returns None. + + Returns: + None + """ + return None + + def get_class_instance(self, resources_path: str, target: Optional[T] = None): + if target is not None: + raise NotImplementedError + + return PAGGenerator( + model_files=PAGFiles.from_directory_with_config(resources_path), + n_samples_per_round=self.samples_per_round, + device="cpu", + temp=self.temperature, + max_len=self.max_length, + ) + + +else: + logger.warning("install cogmol-inference extras to use CLaSS") diff --git a/src/gt4sd/algorithms/controlled_sampling/class_controlled_sampling/implementation.py b/src/gt4sd/algorithms/controlled_sampling/class_controlled_sampling/implementation.py new file mode 100644 index 000000000..2fe70530a --- /dev/null +++ b/src/gt4sd/algorithms/controlled_sampling/class_controlled_sampling/implementation.py @@ -0,0 +1,246 @@ +"""CLaSS implementation.""" +import logging +from typing import List, Optional + +import torch + +from ....extras import EXTRAS_ENABLED + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + +if EXTRAS_ENABLED: + from cog.sample_pipeline import CogMolFiles, get_new_samples, load_vae, mogQ + from cog.z_classifier import ( + load_z_ba_regressor_with_tape_embeddings, + load_zregressor_model, + ) + from pag.sample_pipeline import PAGFiles, filter_heuristic + from pag.z_classifier import load_lumo_clf_model + + class UnsupportedTargetError(RuntimeError): + """Error for target sequence with unknown embedding.""" + + def __init__(self, title: str, detail: str) -> None: + """Initialize UnsupportedTargetError. + + Args: + title: title of the error. + detail: description of the error. + """ + self.type = "UnsupportedTargetError" + self.title = title + self.detail = detail + super().__init__(detail) + + class CogMolGenerator: + def __init__( + self, + protein_sequence: str, + model_files: CogMolFiles, + n_samples_per_round: int = 32, + device: str = "cpu", + # dropout: float = 0.2, + # dropout: Dropout is disabled in eval mode. Defaults to 0.2. 
+ num_proteins_selectivity: int = 10, + temp: float = 1.0, + max_len: int = 100, + bindingdb_date: Optional[str] = None, + ) -> None: + """CogMol generator. + + Args: + protein_sequence: the target sequence for which to generate molecules. + n_samples_per_round: batch size. + model_files: dedicated NamedTuple for artifact filepaths. + device: for example 'cpu'. + num_proteins_selectivity: number of random samples for measuring + selectivity. Defaults to 10. + + Raises: + RuntimeError: in the case extras are disabled. + """ + if not EXTRAS_ENABLED: + raise RuntimeError( + "Can't instantiate CogMolGenerator, extras disabled!" + ) + + self.n_samples_per_round = n_samples_per_round + self.temp = temp + self.max_len = max_len + self.num_proteins_selectivity = num_proteins_selectivity + self.bindingdb_date = bindingdb_date + self.device = device + + self.model = load_vae( + model_files.vae_model, + model_files.vae_config, + model_files.vae_vocab, + device, + ) + + self.clf = load_z_ba_regressor_with_tape_embeddings( + model_path=model_files.ba_model_path, + device=device, + dims=[2048, 1], + dropout=0.2, + ) + self.reg = load_zregressor_model( + model_path=model_files.qed_regressor_model_path, device=device + ) + + self.protein_z_map = torch.load(f=model_files.protein_z_map) # device? + self.protein_emb = self.get_target_embedding( + protein_sequence=protein_sequence + ) + self.protein_sequence = protein_sequence + + # set all models to eval + self.model.eval() + self.clf.eval() + self.reg.eval() + + self.Q_xi_a = mogQ(model_files.mog_model_file, device=device) + + self.Q_xi_a.init_attr_classifiers( + attr_clfs={ + "binding": self.clf, + "qed": self.reg, + "non_binding": self.clf, + }, + clf_targets={ + "binding": 1, + "qed": 0, + "non_binding": 0, + }, + protein_emb_binding=self.protein_emb, + protein_embedding_map=self.protein_z_map, + num_proteins_selectivity=self.num_proteins_selectivity, + ) + + def sample_accepted(self, target: Optional[str] = None) -> List[str]: + if target is not None and target != self.protein_sequence: + self.protein_sequence = target + self.protein_emb = self.get_target_embedding(protein_sequence=target) + self.Q_xi_a.init_attr_classifiers( + attr_clfs={ + "binding": self.clf, + "qed": self.reg, + "non_binding": self.clf, + }, + clf_targets={ + "binding": 1, + "qed": 0, + "non_binding": 0, + }, + protein_emb_binding=self.protein_emb, + protein_embedding_map=self.protein_z_map, + num_proteins_selectivity=self.num_proteins_selectivity, + ) + + samples = get_new_samples( + model=self.model, + Q=self.Q_xi_a, + n_samples=self.n_samples_per_round, + max_len=self.max_len, + temp=self.temp, + ) + return samples[samples["accept_z"] == 1]["smiles"].tolist() + + def get_target_embedding(self, protein_sequence: str) -> torch.Tensor: + """Retrieve embedding of target or raise a dedicated exception. + + Args: + protein_sequence: target amino acid sequence. + + Raises: + UnsupportedTargetError: in case the embedding is not available in + `self.protein_z_map`. + + Returns: + The protein embedding. + """ + try: + return self.protein_z_map[protein_sequence] + except KeyError: + detail = ( + "The provided target is not available in this version: \n" + f"{protein_sequence}" + ) + # prepend a hint on the supported target proteins + if self.bindingdb_date is not None: + detail = ( + f"Only protein sequences published in BindingDB until {self.bindingdb_date} are supported. 
" + ) + detail + + logger.warning(detail) + raise UnsupportedTargetError( + title="The target protein sequence is not supported.", detail=detail + ) + + class PAGGenerator: + def __init__( + self, + model_files: PAGFiles, + n_samples_per_round: int = 32, + device: str = "cpu", + temp: float = 1.0, + max_len: int = 100, + ) -> None: + """PAG generator. + + Args: + n_samples_per_round: batch size. + model_files: dedicated NamedTuple for artifact filepaths. + device: for example 'cpu'. + + Raises: + RuntimeError: in the case extras are disabled. + """ + if not EXTRAS_ENABLED: + raise RuntimeError("Can't instantiate PAGGenerator, extras disabled!") + + self.n_samples_per_round = n_samples_per_round + self.temp = temp + self.max_len = max_len + self.device = device + + self.model = load_vae( + model_files.vae_model, + model_files.vae_config, + model_files.vae_vocab, + device, + ) + + self.clf = load_lumo_clf_model(model_path=model_files.lumo_clf_model_path) + + # set all models to eval + self.model.eval() + self.clf.eval() + + self.Q_xi_a = mogQ(model_files.mog_model_file, device=device) + + self.Q_xi_a.init_attr_classifiers( + attr_clfs={ + "LUMO": self.clf, + }, + clf_targets={ + "LUMO": 1, + }, + ) + + def sample_accepted(self) -> List[str]: + samples = get_new_samples( + model=self.model, + Q=self.Q_xi_a, + n_samples=self.n_samples_per_round, + max_len=self.max_len, + temp=self.temp, + ) + # TODO: allow user input to heuristics? + samples = filter_heuristic(samples) + return samples[samples.accept_z & samples.accept].smiles.tolist() + + +else: + logger.warning("install cogmol-inference extras to use CLaSS") diff --git a/src/gt4sd/algorithms/controlled_sampling/paccmann_gp/__init__.py b/src/gt4sd/algorithms/controlled_sampling/paccmann_gp/__init__.py new file mode 100644 index 000000000..d217a1fe4 --- /dev/null +++ b/src/gt4sd/algorithms/controlled_sampling/paccmann_gp/__init__.py @@ -0,0 +1,5 @@ +"""PaccMannGP initialization.""" + +from .core import PaccMannGP, PaccMannGPGenerator + +__all__ = ["PaccMannGP", "PaccMannGPGenerator"] diff --git a/src/gt4sd/algorithms/controlled_sampling/paccmann_gp/core.py b/src/gt4sd/algorithms/controlled_sampling/paccmann_gp/core.py new file mode 100644 index 000000000..4a4339b06 --- /dev/null +++ b/src/gt4sd/algorithms/controlled_sampling/paccmann_gp/core.py @@ -0,0 +1,253 @@ +"""PaccMann\\ :superscript:`GP` Algorithm. + +PaccMann\\ :superscript:`GP` generation is conditioned via gaussian processes. +""" + +import logging +from dataclasses import field +from typing import Any, Callable, ClassVar, Dict, Iterable, Optional, TypeVar + +from typing_extensions import Protocol, runtime_checkable + +from ....domains.materials import SmallMolecule, validate_molecules +from ....exceptions import InvalidItem +from ...core import AlgorithmConfiguration, GeneratorAlgorithm +from ...registry import ApplicationsRegistry +from .implementation import GPConditionalGenerator + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + +T = TypeVar("T", bound=Any) +S = TypeVar("S", bound=SmallMolecule) +Targeted = Callable[[T], Iterable[Any]] + + +class PaccMannGP(GeneratorAlgorithm[S, T]): + """PaccMann\\ :superscript:`GP` Algorithm.""" + + def __init__( + self, + configuration: AlgorithmConfiguration[S, T], + target: Optional[T], + ): + """Instantiate PaccMannGP ready to generate items. + + Args: + configuration: domain and application + specification defining parameters, types and validations. + target: a target for which to generate items. 
+
+        Example:
+            An example for generating small molecules (SMILES) with high affinity
+            for a target protein::
+
+                configuration = PaccMannGPGenerator()
+                target = {
+                    "qed": {"weight": 1.0},
+                    "molwt": {"target": 200},
+                    "sa": {"weight": 2.0},
+                    "affinity": {"protein": "MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTT"}
+                }
+                paccmann_gp = PaccMannGP(configuration=configuration, target=target)
+                items = list(paccmann_gp.sample(10))
+                print(items)
+        """
+
+        configuration = self.validate_configuration(configuration)
+        # TODO there might also be a validation/check on the target input
+
+        super().__init__(
+            configuration=configuration,  # type:ignore
+            target=target,  # type:ignore
+        )
+
+    def get_generator(
+        self,
+        configuration: AlgorithmConfiguration[S, T],
+        target: Optional[T],
+    ) -> Targeted[T]:
+        """Get the function to sample batches via PaccMannGP's GPConditionalGenerator.
+
+        Args:
+            configuration: helps to set up specific application of PaccMannGP.
+            target: context or condition for the generation.
+
+        Returns:
+            callable with target generating a batch of items.
+        """
+        logger.info("ensure artifacts for the application are present.")
+        self.local_artifacts = configuration.ensure_artifacts()
+        implementation: GPConditionalGenerator = configuration.get_conditional_generator(  # type: ignore
+            self.local_artifacts
+        )
+        return implementation.generate_batch
+
+    def validate_configuration(
+        self, configuration: AlgorithmConfiguration[S, T]
+    ) -> AlgorithmConfiguration[S, T]:
+        @runtime_checkable
+        class AnyPaccMannGPConfiguration(Protocol):
+            """Protocol for PaccMannGP configurations."""
+
+            def get_conditional_generator(
+                self, resources_path: str
+            ) -> GPConditionalGenerator:
+                ...
+
+            def validate_item(self, item: Any) -> S:
+                ...
+
+        # TODO raise InvalidAlgorithmConfiguration
+        assert isinstance(configuration, AnyPaccMannGPConfiguration)
+        assert isinstance(configuration, AlgorithmConfiguration)
+        return configuration
+
+
+@ApplicationsRegistry.register_algorithm_application(PaccMannGP)
+class PaccMannGPGenerator(AlgorithmConfiguration[SmallMolecule, Any]):
+    """
+    Configuration to generate compounds while controlling molecular properties.
+
+    Implementation from the paper: https://doi.org/10.1021/acs.jcim.1c00889.
+    """
+
+    algorithm_type: ClassVar[str] = "controlled_sampling"
+    domain: ClassVar[str] = "materials"
+    algorithm_version: str = "v0"
+
+    batch_size: int = field(
+        default=32,
+        metadata=dict(description="Batch size used for the generative model sampling."),
+    )
+    temperature: float = field(
+        default=1.4,
+        metadata=dict(
+            description="Temperature parameter for the softmax sampling in decoding."
+        ),
+    )
+    generated_length: int = field(
+        default=100,
+        metadata=dict(
+            description="Maximum length in tokens of the generated molecules (relates to the SMILES length)."
+        ),
+    )
+    limit: float = field(
+        default=5.0,
+        metadata=dict(description="Hypercube limits in the latent space."),
+    )
+    acquisition_function: str = field(
+        default="EI",
+        metadata=dict(
+            description=(
+                "Acquisition function used in the Gaussian process. "
+                "More details in https://scikit-optimize.github.io/stable/modules/generated/skopt.gp_minimize.html."
+            )
+        ),
+    )
+    number_of_steps: int = field(
+        default=32,
+        metadata=dict(description="Number of steps for an optimization round."),
+    )
+    number_of_initial_points: int = field(
+        default=16,
+        metadata=dict(description="Number of initial points evaluated."),
+    )
+    initial_point_generator: str = field(
+        default="random",
+        metadata=dict(
+            description=(
+                "Scheme to generate initial points. "
+                "More details in https://scikit-optimize.github.io/stable/modules/generated/skopt.gp_minimize.html."
+            )
+        ),
+    )
+    seed: int = field(
+        default=42,
+        metadata=dict(
+            description="Seed used for random number generation in the optimizer."
+        ),
+    )
+    number_of_optimization_rounds: int = field(
+        default=1,
+        metadata=dict(description="Maximum number of optimization rounds."),
+    )
+    sampling_variance: float = field(
+        default=0.1,
+        metadata=dict(
+            description="Variance of the Gaussian noise applied during sampling from the optimal point."
+        ),
+    )
+    samples_for_evaluation: int = field(
+        default=4,
+        metadata=dict(
+            description="Number of samples averaged for each minimization function evaluation."
+        ),
+    )
+    maximum_number_of_sampling_steps: int = field(
+        default=32,
+        metadata=dict(
+            description="Maximum number of sampling steps in an optimization round."
+        ),
+    )
+
+    def get_target_description(self) -> Dict[str, str]:
+        """Get description of the target for generation.
+
+        Returns:
+            target description.
+        """
+        return {
+            "title": "Scoring functions with parameters",
+            "description": "Scoring functions will be used to generate a score for the generated molecules.",
+            "type": "object",
+        }
+
+    def get_conditional_generator(self, resources_path: str) -> GPConditionalGenerator:
+        """Instantiate the actual generator implementation.
+
+        Args:
+            resources_path: local path to model files.
+
+        Returns:
+            instance with :meth:`generate_batch` method for targeted generation.
+        """
+        return GPConditionalGenerator(
+            resources_path=resources_path,
+            temperature=self.temperature,
+            generated_length=self.generated_length,
+            batch_size=self.batch_size,
+            limit=self.limit,
+            acquisition_function=self.acquisition_function,
+            number_of_steps=self.number_of_steps,
+            number_of_initial_points=self.number_of_initial_points,
+            initial_point_generator=self.initial_point_generator,
+            seed=self.seed,
+            number_of_optimization_rounds=self.number_of_optimization_rounds,
+            sampling_variance=self.sampling_variance,
+            samples_for_evaluation=self.samples_for_evaluation,
+            maximum_number_of_sampling_steps=self.maximum_number_of_sampling_steps,
+        )
+
+    def validate_item(self, item: str) -> SmallMolecule:
+        """Check that item is a valid SMILES.
+
+        Args:
+            item: a generated item that is possibly not valid.
+
+        Raises:
+            InvalidItem: in case the item can not be validated.
+
+        Returns:
+            the validated SMILES.
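+
+        Example (illustrative usage; ``"CCO"`` is a valid SMILES,
+        ``"not-a-smiles"`` is not)::
+
+            configuration.validate_item("CCO")  # returns the validated SmallMolecule
+            configuration.validate_item("not-a-smiles")  # raises InvalidItem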
+ """ + ( + molecules, + _, + ) = validate_molecules([item]) + if molecules[0] is None: + raise InvalidItem( + title="InvalidSMILES", + detail=f'rdkit.Chem.MolFromSmiles returned None for "{item}"', + ) + return SmallMolecule(item) diff --git a/src/gt4sd/algorithms/controlled_sampling/paccmann_gp/implementation.py b/src/gt4sd/algorithms/controlled_sampling/paccmann_gp/implementation.py new file mode 100644 index 000000000..90b0a56d4 --- /dev/null +++ b/src/gt4sd/algorithms/controlled_sampling/paccmann_gp/implementation.py @@ -0,0 +1,244 @@ +"""Implementation of PaccMann^GP conditional generator.""" + +import json +import logging +import os +from typing import Any, Dict, List, Optional, Union + +import torch +from paccmann_chemistry.models.vae import StackGRUDecoder, StackGRUEncoder, TeacherVAE +from paccmann_chemistry.utils.search import SamplingSearch +from paccmann_gp.affinity_minimization import AffinityMinimization +from paccmann_gp.combined_minimization import CombinedMinimization +from paccmann_gp.gp_optimizer import GPOptimizer +from paccmann_gp.mw_minimization import MWMinimization +from paccmann_gp.qed_minimization import QEDMinimization +from paccmann_gp.sa_minimization import SAMinimization +from paccmann_gp.smiles_generator import SmilesGenerator +from paccmann_predictor.models import MODEL_FACTORY +from pytoda.proteins.protein_language import ProteinLanguage +from pytoda.smiles.smiles_language import SMILESLanguage + +from ....frameworks.torch import device_claim + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + +MINIMIZATION_FUNCTIONS = { + "qed": QEDMinimization, + "sa": SAMinimization, + "molwt": MWMinimization, + "affinity": AffinityMinimization, +} + + +class GPConditionalGenerator: + """Conditional generator as implemented in https://doi.org/10.1021/acs.jcim.1c00889.""" + + def __init__( + self, + resources_path: str, + temperature: float = 1.4, + generated_length: int = 100, + batch_size: int = 32, + limit: float = 5.0, + acquisition_function: str = "EI", + number_of_steps: int = 32, + number_of_initial_points: int = 16, + initial_point_generator: str = "random", + seed: int = 42, + number_of_optimization_rounds: int = 1, + sampling_variance: float = 0.1, + samples_for_evaluation: int = 4, + maximum_number_of_sampling_steps: int = 32, + device: Optional[Union[torch.device, str]] = None, + ) -> None: + """Initialize the conditional generator. + + Args: + resources_path: directory where to find models and parameters. + temperature: temperature parameter for the softmax sampling in decoding. Defaults to 1.4. + generated_length: maximum length in tokens of the generated molcules (relates to the SMILES length). Defaults to 100. + batch_size: batch size used for the generative model sampling. Defaults to 16. + limit: hypercube limits in the latent space. Defaults to 5.0. + acquisition_function: acquisition function used in the Gaussian process. Defaults to "EI". More details in https://scikit-optimize.github.io/stable/modules/generated/skopt.gp_minimize.html. + number_of_steps: number of steps for an optmization round. Defaults to 32. + number_of_initial_points: number of initial points evaluated. Defaults to 16. + initial_point_generator: scheme to generate initial points. Defaults to "random". More details in https://scikit-optimize.github.io/stable/modules/generated/skopt.gp_minimize.html. + seed: seed used for random number generation in the optimizer. Defaults to 42. + number_of_optimization_rounds: maximum number of optimization rounds. 
+                Defaults to 1.
+            sampling_variance: variance of the Gaussian noise applied during sampling from the optimal point. Defaults to 0.1.
+            samples_for_evaluation: number of samples averaged for each minimization function evaluation. Defaults to 4.
+            maximum_number_of_sampling_steps: maximum number of sampling steps in an optimization round. Defaults to 32.
+            device: device used for inference. Defaults to None, i.e., a default device is picked ("gpu" if available, "cpu" otherwise).
+        """
+        # device
+        self.device = device_claim(device)
+        # setting sampling parameters
+        self.temperature = temperature
+        self.generated_length = generated_length
+        self.batch_size = batch_size
+        # setting VAE parameters
+        self.svae_params = dict()
+        with open(os.path.join(resources_path, "vae_model_params.json"), "r") as f:
+            self.svae_params.update(json.load(f))
+        smiles_language = SMILESLanguage.load(
+            os.path.join(resources_path, "selfies_language.pkl")
+        )
+        # initialize encoder, decoder, testVAE, and GP_generator_MW
+        self.gru_encoder = StackGRUEncoder(self.svae_params)
+        self.gru_decoder = StackGRUDecoder(self.svae_params)
+        self.gru_vae = TeacherVAE(self.gru_encoder, self.gru_decoder)
+        self.gru_vae.load_state_dict(
+            torch.load(
+                os.path.join(resources_path, "vae_weights.pt"),
+                map_location=self.device,
+            )
+        )
+        self.gru_vae._associate_language(smiles_language)
+        self.gru_vae.eval()
+        self.smiles_generator = SmilesGenerator(
+            self.gru_vae,
+            search=SamplingSearch(temperature=self.temperature),
+            generated_length=self.generated_length,
+        )
+        self.latent_dim = self.gru_decoder.latent_dim
+        # setting affinity predictor parameters
+        with open(os.path.join(resources_path, "mca_model_params.json")) as f:
+            self.predictor_params = json.load(f)
+        self.affinity_predictor = MODEL_FACTORY["bimodal_mca"](self.predictor_params)
+        self.affinity_predictor.load(
+            os.path.join(resources_path, "mca_weights.pt"),
+            map_location=self.device,
+        )
+        affinity_protein_language = ProteinLanguage.load(
+            os.path.join(resources_path, "protein_language.pkl")
+        )
+        affinity_smiles_language = SMILESLanguage.load(
+            os.path.join(resources_path, "smiles_language.pkl")
+        )
+        self.affinity_predictor._associate_language(affinity_smiles_language)
+        self.affinity_predictor._associate_language(affinity_protein_language)
+        self.affinity_predictor.eval()
+        # setting optimizer parameters
+        self.limit = limit
+        self.acquisition_function = acquisition_function
+        self.number_of_initial_points = number_of_initial_points
+        if number_of_steps < self.number_of_initial_points:
+            logger.warning(
+                "number of initial points is larger than number of steps "
+                f"({self.number_of_initial_points}/{number_of_steps}). "
+                f"Resetting number of steps to {self.number_of_initial_points}."
+            )
+            self.number_of_steps = self.number_of_initial_points
+        else:
+            self.number_of_steps = number_of_steps
+        self.initial_point_generator = initial_point_generator
+        self.seed = seed
+        self.number_of_optimization_rounds = number_of_optimization_rounds
+        self.sampling_variance = sampling_variance
+        self.samples_for_evaluation = samples_for_evaluation
+        self.maximum_number_of_sampling_steps = maximum_number_of_sampling_steps
+
+    def target_to_minimization_function(
+        self, target: Union[Dict[str, Dict[str, Any]], str]
+    ) -> CombinedMinimization:
+        """Use the target to configure a minimization function.
+
+        Args:
+            target: dictionary or JSON string describing the optimization target.
+
+        Returns:
+            a minimization function.
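+
+        Example (illustrative; ``generator`` is a :class:`GPConditionalGenerator`
+        instance and the target mirrors the supported scoring functions)::
+
+            minimization_function = generator.target_to_minimization_function(
+                {
+                    "qed": {"weight": 1.0},
+                    "molwt": {"target": 200},
+                    "affinity": {"protein": "MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTT"},
+                }
+            )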
+ """ + if isinstance(target, str): + target_dictionary = json.loads(target) + elif isinstance(target, dict): + target_dictionary = target + else: + raise ValueError( + f"{target} of type {type(target)} is not supported: provide 'str' or 'Dict[str, Dict[str, Any]]'" + ) + minimization_functions = [] + weights = [] + for minimization_function_name, parameters in target_dictionary.items(): + weight = 1.0 + if "weight" in parameters: + weight = parameters.pop("weight") + function_parameters = { + **parameters, + **{ + "batch_size": self.samples_for_evaluation, + "smiles_decoder": self.smiles_generator, + }, + } + minimization_function = MINIMIZATION_FUNCTIONS[minimization_function_name] + if minimization_function_name == "affinity": + function_parameters["affinity_predictor"] = self.affinity_predictor + minimization_functions.append(minimization_function(**function_parameters)) + weights.append(weight) + return CombinedMinimization( + minimization_functions=minimization_functions, + batch_size=1, + function_weights=weights, + ) + + def generate_batch(self, target: Any) -> List[str]: + """Generate molecules given a target. + + Args: + target: dictionary or JSON string describing the optimization target. + + Returns: + a list of molecules as SMILES string. + """ + # make sure the seed is transformed to avoid redundancy over multiple calls (using Knuth multiplicative hashing) + self.seed = self.seed * 2654435761 % 2 ** 32 + logger.info(f"configuring optimization for target: {target}") + # target configuration + self.target = target + self.minimization_function = self.target_to_minimization_function(self.target) + # optimizer configuration + self.target_optimizer = GPOptimizer(self.minimization_function.evaluate) + optimization_parameters = dict( + dimensions=[(-self.limit, self.limit)] * self.latent_dim, + acq_func=self.acquisition_function, + n_calls=self.number_of_steps, + n_initial_points=self.number_of_initial_points, + initial_point_generator=self.initial_point_generator, + random_state=self.seed, + ) + logger.info( + f"running optimization with the following parameters: {optimization_parameters}" + ) + smiles_set = set() + logger.info( + f"running at most {self.number_of_optimization_rounds} optmization rounds" + ) + for optimization_round in range(self.number_of_optimization_rounds): + logger.info(f"starting round {optimization_round + 1}") + optimization_parameters["random_state"] += optimization_round # type:ignore + res = self.target_optimizer.optimize(optimization_parameters) + latent_point = torch.tensor([[res.x]]) + smiles_set_per_round = set() + + logger.info(f"starting sampling for {optimization_round + 1}") + for _ in range(self.maximum_number_of_sampling_steps): + generated_smiles = self.smiles_generator.generate_smiles( + latent_point.repeat(1, self.batch_size, 1) + + torch.cat( + ( + torch.zeros(1, 1, self.latent_dim), + (self.sampling_variance ** 0.5) + * torch.randn(1, self.batch_size - 1, self.latent_dim), + ), + dim=1, + ) + ) + smiles_set_per_round.update(set(generated_smiles)) + smiles_set.update(smiles_set_per_round) + logger.info(f"completing round {optimization_round + 1}") + logger.info(f"generated {len(smiles_set)} molecules in the current run") + return list( + [molecule_smiles for molecule_smiles in smiles_set if molecule_smiles] + ) diff --git a/src/gt4sd/algorithms/controlled_sampling/tests/__init__.py b/src/gt4sd/algorithms/controlled_sampling/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git 
a/src/gt4sd/algorithms/controlled_sampling/tests/test_advanced_manufacturing.py b/src/gt4sd/algorithms/controlled_sampling/tests/test_advanced_manufacturing.py new file mode 100644 index 000000000..15ac00625 --- /dev/null +++ b/src/gt4sd/algorithms/controlled_sampling/tests/test_advanced_manufacturing.py @@ -0,0 +1,155 @@ +"""AdvancedManufacturing tests.""" + +import pickle +from typing import ClassVar, Type + +import pytest + +from gt4sd.algorithms.controlled_sampling.advanced_manufacturing import ( + AdvancedManufacturing, + CatalystGenerator, +) +from gt4sd.algorithms.core import AlgorithmConfiguration +from gt4sd.algorithms.registry import ApplicationsRegistry + + +def get_classvar_type(class_var): + """Extract type from ClassVar type annotation: `ClassVar[T]] -> T`.""" + return class_var.__args__[0] + + +@pytest.mark.parametrize( + "config_class, algorithm_type, domain, algorithm_name", + [ + ( + CatalystGenerator, + "controlled_sampling", + "materials", + AdvancedManufacturing.__name__, + ) + ], +) +def test_config_class( + config_class: Type[AlgorithmConfiguration], + algorithm_type: str, + domain: str, + algorithm_name: str, +): + assert config_class.algorithm_type == algorithm_type + assert config_class.domain == domain + assert config_class.algorithm_name == algorithm_name + + for keyword, type_annotation in config_class.__annotations__.items(): + if keyword in ("algorithm_type", "domain", "algorithm_name"): + assert type_annotation.__origin__ is ClassVar # type: ignore + assert str == get_classvar_type(type_annotation) + + +@pytest.mark.parametrize( + "config_class", + [ + (CatalystGenerator), + ], +) +def test_config_instance(config_class: Type[AlgorithmConfiguration]): + config = config_class() # type:ignore + assert config.algorithm_application == config_class.__name__ + + +@pytest.mark.parametrize( + "config_class", + [ + (CatalystGenerator), + ], +) +def test_available_versions(config_class: Type[AlgorithmConfiguration]): + versions = config_class.list_versions() + assert "v0" in versions + + +@pytest.mark.parametrize( + "config, example_target, algorithm", + [ + ( + CatalystGenerator, + 10.0, + AdvancedManufacturing, + ), + ( + CatalystGenerator, + "10.0", + AdvancedManufacturing, + ), + ], +) +def test_generation_via_import(config, example_target, algorithm): + advanced_manufacturing = algorithm( + configuration=config(number_of_steps=10, number_of_points=10), + target=example_target, + ) + items = list(advanced_manufacturing.sample(5)) + assert len(items) == 5 + + +@pytest.mark.parametrize( + "algorithm_application, target", + [ + ( + CatalystGenerator.__name__, + 10.0, + ), + ( + CatalystGenerator.__name__, + "10.0", + ), + ], +) +def test_generation_via_registry(target, algorithm_application): + advanced_manufacturing = ApplicationsRegistry.get_application_instance( + target=target, + algorithm_type="controlled_sampling", + domain="materials", + algorithm_name=AdvancedManufacturing.__name__, + algorithm_application=algorithm_application, + generated_length=100, + number_of_steps=10, + number_of_points=10, + ) + items = list(advanced_manufacturing.sample(5)) + assert len(items) == 5 + + +@pytest.mark.parametrize( + "config_class", + [ + (CatalystGenerator), + ], +) +def test_configuration_pickable(config_class: Type[AlgorithmConfiguration]): + # implementation + obj = config_class(algorithm_version="test") + + # --- + import inspect + + inspect.getmodule(config_class) + # --- + pickled_obj = pickle.dumps(obj) + restored_obj = pickle.loads(pickled_obj) + assert 
restored_obj.algorithm_version == "test" + assert restored_obj == obj + + # registered + Config = ApplicationsRegistry.get_application( + algorithm_type="controlled_sampling", + domain="materials", + algorithm_name=AdvancedManufacturing.__name__, + algorithm_application=config_class.__name__, + ).configuration_class + + obj = Config(algorithm_version="test") + pickled_obj = pickle.dumps(obj) + restored_obj = pickle.loads(pickled_obj) + + assert restored_obj.algorithm_version == "test" + assert restored_obj == obj diff --git a/src/gt4sd/algorithms/controlled_sampling/tests/test_class_controlled_sampling.py b/src/gt4sd/algorithms/controlled_sampling/tests/test_class_controlled_sampling.py new file mode 100644 index 000000000..49ee47ce9 --- /dev/null +++ b/src/gt4sd/algorithms/controlled_sampling/tests/test_class_controlled_sampling.py @@ -0,0 +1,200 @@ +"""CLaSS tests.""" + +import pickle +from typing import ClassVar, Type + +import pytest + +from gt4sd.algorithms.core import AlgorithmConfiguration +from gt4sd.algorithms.registry import ApplicationsRegistry +from gt4sd.extras import EXTRAS_ENABLED + +if not EXTRAS_ENABLED: + pytest.skip("Extras from custom PyPI disabled", allow_module_level=True) +else: + from gt4sd.algorithms.controlled_sampling.class_controlled_sampling import ( + PAG, + CLaSS, + CogMol, + ) + from gt4sd.algorithms.controlled_sampling.class_controlled_sampling.implementation import ( + UnsupportedTargetError, + ) + + +def get_classvar_type(class_var): + """Extract type from ClassVar type annotation: `ClassVar[T]] -> T`.""" + return class_var.__args__[0] + + +MPRO = "SGFRKMAFPSGKVEGCMVQVTCGTTTLNGLWLDDVVYCPRHVICTSEDMLNPNYEDLLIRKSNHNFLVQAGNVQLRVIGHSMQNCVLKLKVDTANPKTPKYKFVRIQPGQTFSVLACYNGSPSGVYQCAMRPNFTIKGSFLNGSCGSVGFNIDYDCVSFCYMHHMELPTGVHAGTDLEGNFYGPFVDRQTAQAAGTDTTITVNVLAWLYAAVINGDRWFLNRFTTTLNDFNLVAMKYNYEPLTQDHVDILGPLSAQTGIAVLDMCASLKELLQNGMNGRTILGSALLEDEFTPFDVVRQCSGVTFQ" + + +@pytest.mark.parametrize( + "config_class, algorithm_type, domain, algorithm_name", + [ + ( + CogMol, + "controlled_sampling", + "materials", + CLaSS.__name__, + ), + ( + PAG, + "controlled_sampling", + "materials", + CLaSS.__name__, + ), + ], +) +def test_config_class( + config_class: Type[AlgorithmConfiguration], + algorithm_type: str, + domain: str, + algorithm_name: str, +): + assert config_class.algorithm_type == algorithm_type + assert config_class.domain == domain + assert config_class.algorithm_name == algorithm_name + + for keyword, type_annotation in config_class.__annotations__.items(): + if keyword in ("algorithm_type", "domain", "algorithm_name"): + assert type_annotation.__origin__ is ClassVar # type: ignore + assert str == get_classvar_type(type_annotation) + + +@pytest.mark.parametrize( + "config_class", + [ + (CogMol), + (PAG), + ], +) +def test_config_instance(config_class: Type[AlgorithmConfiguration]): + config = config_class() # type:ignore + assert config.algorithm_application == config_class.__name__ + + +@pytest.mark.parametrize( + "config_class", + [ + (CogMol), + (PAG), + ], +) +def test_available_versions(config_class: Type[AlgorithmConfiguration]): + versions = config_class.list_versions() + assert "v0" in versions + + +@pytest.mark.parametrize( + "config, example_target, algorithm, kwargs", + [ + ( + CogMol, + MPRO, + CLaSS, + { + "samples_per_round": 173, + "max_length": 40, + "temperature": 0.8, + "num_proteins_selectivity": 20, + }, + ), + ( + PAG, + None, + CLaSS, + { + "samples_per_round": 173, + "max_length": 40, + "temperature": 0.8, + }, + ), + ], +) +def 
test_generation_via_import(config, example_target, algorithm, kwargs):
+    class_sampling = algorithm(
+        configuration=config(**kwargs),
+        target=example_target,
+    )
+    items = list(class_sampling.sample(5))
+    assert len(items) == 5
+
+
+@pytest.mark.parametrize(
+    "algorithm_application, target",
+    [
+        (
+            CogMol.__name__,
+            MPRO,
+        ),
+        (
+            PAG.__name__,
+            None,
+        ),
+    ],
+)
+def test_generation_via_registry(target, algorithm_application):
+    class_sampling = ApplicationsRegistry.get_application_instance(
+        target=target,
+        algorithm_type="controlled_sampling",
+        domain="materials",
+        algorithm_name=CLaSS.__name__,
+        algorithm_application=algorithm_application,
+    )
+    items = list(class_sampling.sample(5))
+    assert len(items) == 5
+
+
+def test_unsupported_target(algorithm_application=CogMol.__name__, target=MPRO):
+    invalid_target = target[:30]  # assuming this makes it invalid
+
+    # on construction
+    with pytest.raises(UnsupportedTargetError):
+        ApplicationsRegistry.get_application_instance(
+            target=invalid_target,
+            algorithm_type="controlled_sampling",
+            domain="materials",
+            algorithm_name=CLaSS.__name__,
+            algorithm_application=algorithm_application,
+        )
+
+    # on sampling with changed target
+    config = CogMol()
+    implementation = config.get_class_instance(  # type: ignore
+        resources_path=config.ensure_artifacts(), target=target
+    )
+    with pytest.raises(UnsupportedTargetError):
+        implementation.sample_accepted(invalid_target)
+
+
+@pytest.mark.parametrize("config_class", [(CogMol), (PAG)])
+def test_configuration_pickable(config_class: Type[AlgorithmConfiguration]):
+    # implementation
+    obj = config_class(algorithm_version="test")
+
+    # ---
+    import inspect
+
+    inspect.getmodule(config_class)
+    # ---
+    pickled_obj = pickle.dumps(obj)
+    restored_obj = pickle.loads(pickled_obj)
+    assert restored_obj.algorithm_version == "test"
+    assert restored_obj == obj
+
+    # registered
+    Config = ApplicationsRegistry.get_application(
+        algorithm_type="controlled_sampling",
+        domain="materials",
+        algorithm_name=CLaSS.__name__,
+        algorithm_application=config_class.__name__,
+    ).configuration_class
+
+    obj = Config(algorithm_version="test")
+    pickled_obj = pickle.dumps(obj)
+    restored_obj = pickle.loads(pickled_obj)
+
+    assert restored_obj.algorithm_version == "test"
+    assert restored_obj == obj
diff --git a/src/gt4sd/algorithms/controlled_sampling/tests/test_paccmann_gp.py b/src/gt4sd/algorithms/controlled_sampling/tests/test_paccmann_gp.py
new file mode 100644
index 000000000..f0008cd8e
--- /dev/null
+++ b/src/gt4sd/algorithms/controlled_sampling/tests/test_paccmann_gp.py
@@ -0,0 +1,132 @@
+"""PaccMannGP tests."""
+
+from typing import ClassVar, Type
+
+import pytest
+
+from gt4sd.algorithms.controlled_sampling.paccmann_gp import (
+    PaccMannGP,
+    PaccMannGPGenerator,
+)
+from gt4sd.algorithms.core import AlgorithmConfiguration
+from gt4sd.algorithms.registry import ApplicationsRegistry
+
+TARGET = {
+    "qed": {"weight": 1.0},
+    "molwt": {"target": 200},
+    "sa": {"weight": 2.0},
+    "affinity": {"protein": "MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTT"},
+}
+PARAMETERS = {
+    "number_of_steps": 8,
+    "number_of_initial_points": 4,
+    "number_of_optimization_rounds": 1,
+    "samples_for_evaluation": 2,
+    "maximum_number_of_sampling_steps": 4,
+}
+
+
+def get_classvar_type(class_var):
+    """Extract type from ClassVar type annotation: `ClassVar[T] -> T`."""
+    return class_var.__args__[0]
+
+
+@pytest.mark.parametrize(
+    "config_class, algorithm_type, domain, algorithm_name",
+    [
+        (
+            PaccMannGPGenerator,
+
"controlled_sampling", + "materials", + PaccMannGP.__name__, + ), + ], +) +def test_config_class( + config_class: Type[AlgorithmConfiguration], + algorithm_type: str, + domain: str, + algorithm_name: str, +): + assert config_class.algorithm_type == algorithm_type + assert config_class.domain == domain + assert config_class.algorithm_name == algorithm_name + + for keyword, type_annotation in config_class.__annotations__.items(): + if keyword in ("algorithm_type", "domain", "algorithm_name"): + assert type_annotation.__origin__ is ClassVar # type: ignore + assert str == get_classvar_type(type_annotation) + + +@pytest.mark.parametrize( + "config_class", + [ + (PaccMannGPGenerator), + ], +) +def test_config_instance(config_class: Type[AlgorithmConfiguration]): + config = config_class() # type:ignore + assert config.algorithm_application == config_class.__name__ + + +@pytest.mark.parametrize( + "config_class", + [ + (PaccMannGPGenerator), + ], +) +def test_available_versions(config_class: Type[AlgorithmConfiguration]): + versions = config_class.list_versions() + assert "v0" in versions + + +@pytest.mark.parametrize( + "config, algorithm, algorithm_parameters", + [ + (PaccMannGPGenerator, PaccMannGP, PARAMETERS), + ], +) +def test_generation_via_import(config, algorithm, algorithm_parameters): + parameters = { + "batch_size": 1, + } + for param, value in algorithm_parameters.items(): + parameters[param] = value + config = config(**parameters) + algorithm = algorithm(configuration=config, target=TARGET) + items = list(algorithm.sample(1)) + assert len(items) == 1 + + +@pytest.mark.parametrize( + "algorithm_application, algorithm_type, domain, algorithm_name, algorithm_parameters", + [ + ( + PaccMannGPGenerator.__name__, + "controlled_sampling", + "materials", + PaccMannGP.__name__, + PARAMETERS, + ), + ], +) +def test_generation_via_registry( + algorithm_type, + domain, + algorithm_name, + algorithm_application, + algorithm_parameters, +): + parameters = { + "target": TARGET, + "algorithm_type": algorithm_type, + "domain": domain, + "algorithm_name": algorithm_name, + "algorithm_application": algorithm_application, + "batch_size": 1, + } + for param, value in algorithm_parameters.items(): + parameters[param] = value + algorithm = ApplicationsRegistry.get_application_instance(**parameters) + items = list(algorithm.sample(1)) + assert len(items) == 1 diff --git a/src/gt4sd/algorithms/core.py b/src/gt4sd/algorithms/core.py new file mode 100644 index 000000000..e4133ea7d --- /dev/null +++ b/src/gt4sd/algorithms/core.py @@ -0,0 +1,493 @@ +"""Bases classes and core code used across multiple algorithms.""" + +from __future__ import annotations + +import logging +import os +import signal +from abc import ABC, abstractmethod +from copy import deepcopy +from dataclasses import dataclass +from functools import partial +from typing import ( + Any, + Callable, + ClassVar, + Dict, + Generic, + Iterable, + Iterator, + Optional, + Set, + Type, + TypeVar, + Union, +) + +from ..configuration import ( + GT4SDConfiguration, + get_algorithm_subdirectories_in_cache, + get_algorithm_subdirectories_with_s3, + get_cached_algorithm_path, + sync_algorithm_with_s3, +) +from ..exceptions import InvalidItem, S3SyncError, SamplingError + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + +gt4sd_configuration_instance = GT4SDConfiguration.get_instance() + +# leave typing generic for algorithm implementation +S = TypeVar("S") # used for generated items +T = TypeVar("T") # used for target of 
generation
+U = TypeVar("U")  # used for additional context (e.g. part of target definition)
+
+# callable taking a target
+Targeted = Callable[[T], Iterable[Any]]
+# callable not taking any target
+Untargeted = Callable[[], Iterable[Any]]
+
+
+class GeneratorAlgorithm(ABC, Generic[S, T]):
+    """Interface for automated generation via an :class:`AlgorithmConfiguration`."""
+
+    generator: Union[Untargeted, Targeted[T]]
+    target: Optional[T]
+
+    #: The maximum amount of time we should let the algorithm run
+    max_runtime: int = gt4sd_configuration_instance.gt4sd_max_runtime
+    #: The maximum number of samples a user can try to run in one go
+    max_samples: int = gt4sd_configuration_instance.gt4sd_max_number_of_samples
+
+    generate: Untargeted
+
+    def __init__(
+        self,
+        configuration: AlgorithmConfiguration[S, T],
+        target: Optional[T] = None,
+    ):
+        """Targeted or untargeted generation.
+
+        Args:
+            configuration: application specific helper that allows to setup the
+                generator.
+            target: context or condition for the generation. Defaults to None.
+        """
+        logger.info(
+            f"running {self.__class__.__name__} with configuration={configuration}"
+        )
+        generator = self.get_generator(configuration, target)
+        setattr(
+            self,
+            "generate",
+            self._setup_untargeted_generator(
+                configuration=configuration, generator=generator, target=target
+            ),
+        )
+
+    @abstractmethod
+    def get_generator(
+        self,
+        configuration: AlgorithmConfiguration[S, T],
+        target: Optional[T],
+    ) -> Union[Untargeted, Targeted[T]]:
+        """Set up the detail implementation using the configuration.
+
+        Note:
+            This is the major method to implement in child classes, it is called
+            at instantiation of the GeneratorAlgorithm and must return a callable:
+
+            - Either :obj:`Untargeted`: the callable takes no arguments,
+              and target has to be :obj:`None`.
+            - Or :obj:`Targeted`: the callable takes the target (which must not
+              be :obj:`None`).
+
+        Args:
+            configuration: application specific helper that allows to setup the
+                generator.
+            target: context or condition for the generation. Defaults to None.
+
+        Returns:
+            generator, the detail implementation used for generation.
+            If the target is None, the generator is assumed to be untargeted.
+        """
+
+    def _setup_untargeted_generator(
+        self,
+        configuration: AlgorithmConfiguration[S, T],
+        generator: Union[Untargeted, Targeted[T]],
+        target: Optional[T] = None,
+    ) -> Untargeted:
+        """Targeted or untargeted generation.
+
+        Args:
+            configuration: application specific helper that allows to setup the
+                generator.
+            generator: the detail implementation used for generation.
+                If the target is None, the generator is assumed to be untargeted.
+            target: context or condition for the generation. Defaults to None.
+        """
+        self.configuration = configuration
+        self.target = target
+        self.generator = generator
+
+        if target is None:
+            return self.generator  # type: ignore
+        else:
+            return partial(self.generator, self.target)  # type: ignore
+
+    def timeout(self, signum, frame):
+        raise TimeoutError(
+            "Alarm signal received, probably because a signal.alarm timed out.",
+        )
+
+    def sample(self, number_of_items: int = 100) -> Iterator[S]:
+        """Generate a number of unique and valid items.
+
+        Filters duplicate items and iterates batches of generated items to reach
+        the desired number of samples, but the number of yielded items is not
+        guaranteed:
+        In case the generate method does not create new samples for
+        GT4SD_MAX_NUMBER_OF_STUCK_CALLS times, it will terminate the
+        sampling process.
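+
+        Note:
+            Sampling is additionally guarded by a ``signal.alarm`` set to
+            ``max_runtime`` seconds: on timeout, sampling stops and a
+            :class:`SamplingError` is raised if nothing was generated.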
+
+        Args:
+            number_of_items: number of items to generate.
+                Defaults to 100.
+
+        Raises:
+            SamplingError: when requesting too many items or when no items were yielded.
+                The latter happens when no new samples are generated for a number of
+                calls or when generation takes longer than the allowed time limit.
+
+        Yields:
+            the items.
+        """
+
+        if number_of_items > self.max_samples:
+            detail = (
+                f"{number_of_items} is too many items to generate, "
+                f"must be under {self.max_samples+1} samples."
+            )
+            logger.warning(detail)
+            raise SamplingError(title="Exceeding max_samples", detail=detail)
+
+        def raise_if_none_sampled(items: set, detail: str):
+            """If exiting early there should be at least one generated item.
+
+            Args:
+                items: to check if it's empty.
+                detail: error message in case the exception is raised.
+
+            Raises:
+                SamplingError: using the given detail.
+            """
+            if len(items) == 0:
+                raise SamplingError(
+                    title="No samples generated",
+                    detail="No samples generated." + detail,
+                )
+
+        item_set = set()
+        stuck_counter = 0
+        item_set_length = 0
+        signal.signal(signal.SIGALRM, self.timeout)
+        signal.alarm(self.max_runtime)
+        try:
+            while True:
+                generated_items = self.generate()  # type:ignore
+                for item in generated_items:
+                    if item in item_set:
+                        continue
+                    else:
+                        try:
+                            valid_item = self.configuration.validate_item(item)
+                            yield valid_item
+                            item_set.add(item)
+                            if len(item_set) == number_of_items:
+                                signal.alarm(0)
+                                return
+                        except InvalidItem as error:
+                            logger.debug(
+                                f"item {item} could not be validated, "
+                                f"raising {error.title}: {error.detail}"
+                            )
+                            continue
+                # make sure we don't keep sampling more than a given number of times,
+                # in case no new items are generated.
+                if len(item_set) == item_set_length:
+                    stuck_counter += 1
+                else:
+                    stuck_counter = 0
+                if (
+                    stuck_counter
+                    >= gt4sd_configuration_instance.gt4sd_max_number_of_stuck_calls
+                ):
+                    detail = f"no novel samples generated for more than {gt4sd_configuration_instance.gt4sd_max_number_of_stuck_calls} cycles"
+                    logger.warning(detail + ", exiting")
+                    signal.alarm(0)
+                    raise_if_none_sampled(items=item_set, detail=detail)
+                    return
+                item_set_length = len(item_set)
+        except TimeoutError:
+            detail = f"Samples took longer than {self.max_runtime} seconds to generate"
+            logger.warning(detail + ", exiting")
+            raise_if_none_sampled(items=item_set, detail=detail)
+        signal.alarm(0)
+
+    def validate_configuration(
+        self, configuration: AlgorithmConfiguration
+    ) -> AlgorithmConfiguration:
+        """Overload to validate a configuration for the algorithm.
+
+        Args:
+            configuration: the algorithm configuration.
+
+        Raises:
+            InvalidAlgorithmConfiguration: in case the configuration for the algorithm is invalid.
+
+        Returns:
+            the validated configuration.
+        """
+        logger.info("no parameters validation")
+        return configuration
+
+
+@dataclass
+class AlgorithmConfiguration(Generic[S, T]):
+    """Algorithm parameter definitions and implementation setup.
+
+    The signature of this class constructor (given by the instance attributes) is used
+    for the REST API and needs to be serializable.
+
+    Child classes will add additional instance attributes to configure their respective
+    algorithms. This will require setting default values for all of the attributes defined
+    here.
+    However, the values for :attr:`algorithm_name` and :attr:`algorithm_application`
+    are set by the registering decorator.
+
+    This strict setup has the following desired effects:
+
+    1. Ease child implementation.
For example:: + + from typing import ClassVar + + from gt4sd.algorithms.registry import ApplicationsRegistry + from gt4sd.algorithms.core import AlgorithmConfiguration + + @ApplicationsRegistry.register_algorithm_application(ChildOfGeneratorAlgorithm) + class ConfigurationForChildOfGeneratorAlgorithm(AlgorithmConfiguration): + algorithm_type: ClassVar[str] = 'generation' + domain: ClassVar[str] = 'materials' + algorithm_version: str = 'version3.14' + actual_parameter: float = 1.61 + + # no __init__ definition required + + + 2. Retrieve the algorithm and configuration easily (via the four class attributes) + from the :class:`ApplicationsRegistry`. + For example:: + + from gt4sd.algorithms.registry import ApplicationsRegistry + + application = ApplicationsRegistry.get_application( + algorithm_type='generation', + domain='materials', + algorithm_name='ChildOfGeneratorAlgorithm', + algorithm_application='ConfigurationForChildOfGeneratorAlgorithm', + ) + Algorithm = application.algorithm_class + Configuration = application.configuration_class + + 3. An effortless validation at instantiation via :mod:`pydantic`. + + 4. An effortless mapping to artifacts on s3, see :meth:`ensure_artifacts`. + + Todo: + show how to register a configuration manually (in case it applies to multiple + algorithms and/or applications) + + """ + + #: General type of generative algorithm. + algorithm_type: ClassVar[str] + #: General application domain. Hints at input/output types. + domain: ClassVar[str] + #: Name of the algorithm to use with this configuration. + #: + #: Will be set when registering to :class:`ApplicationsRegistry` + algorithm_name: ClassVar[str] + #: Unique name for the application that is the use of this + #: configuration together with a specific algorithm. + #: + #: Will be set when registering to :class:`ApplicationsRegistry`, + #: but can be given by direct registration (See :meth:`register_algorithm_application`) + algorithm_application: ClassVar[str] + + #: To differentiate between different versions of an application. + #: + #: There is no imposed naming convention. + algorithm_version: str = "" + + def get_target_description(self) -> Optional[Dict[str, str]]: + """Get description of the target for generation. + + Returns: + target description, returns None in case no target is used. + """ + return { + "title": "Target for generation", + "type": "object", + "description": "Optional target for generation.", + } + + def to_dict(self) -> Dict[str, Any]: + """Represent the configuration as a dictionary. + + Returns: + description of the configuration with parameters description. + """ + base_configuration_fields_set = set( + AlgorithmConfiguration.__dataclass_fields__.keys() # type:ignore + ) + application_configuration_dict = dict(description=self.__doc__) + for name, base_description in self.__pydantic_model__.schema()[ # type:ignore + "properties" + ].items(): + if name not in base_configuration_fields_set: + description = dict( + getattr( + self.__dataclass_fields__[name], "metadata", {} # type:ignore + ) + ) + description.update(base_description) + if "default" in description: + description["optional"] = True + else: + description["optional"] = False + application_configuration_dict[name] = description # type:ignore + return application_configuration_dict + + def validate_item(self, item: Any) -> S: + """Overload to validate an item. + + Args: + item: validate an item. + + Raises: + InvalidItem: in case the item can not be validated. + + Returns: + S: the validated item. 
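+
+        Example (illustrative sketch of an overload in a child class;
+        ``MySmilesConfiguration`` is hypothetical)::
+
+            class MySmilesConfiguration(AlgorithmConfiguration[str, str]):
+                def validate_item(self, item: str) -> str:
+                    if not item:
+                        raise InvalidItem(title="EmptyItem", detail="empty string")
+                    return item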
+ """ + # no item validation + return item + + def ensure_artifacts(self) -> str: + """The artifacts matching the path defined by class attributes are downloaded. + + That is all objects under ``algorithm_type/algorithm_name/algorithm_application/algorithm_version`` + in the bucket are downloaded. + + Returns: + str: the common local path of the matching artifacts. + """ + prefix = os.path.join( + self.algorithm_type, + self.algorithm_name, + self.algorithm_application, + self.algorithm_version, + ) + try: + local_path = sync_algorithm_with_s3(prefix) + except (KeyError, S3SyncError) as error: + logger.info( + f"searching S3 raised {error.__class__.__name__}, using local cache only." + ) + logger.debug(error) + local_path = get_cached_algorithm_path(prefix) + if not os.path.isdir(local_path): + raise OSError( + f"artifacts directory {local_path} does not exist locally, and syncing with s3 failed: {error}" + ) + + return local_path + + @classmethod + def list_versions(cls) -> Set[str]: + """Get possible algorithm versions. + + S3 is searched as well as the local cache is searched for matching versions. + + Returns: + viable values as :attr:`algorithm_version` for the environment. + """ + + prefix = os.path.join( + cls.algorithm_type, cls.algorithm_name, cls.algorithm_application + ) + try: + versions = get_algorithm_subdirectories_with_s3(prefix) + except (KeyError, S3SyncError) as error: + logger.info( + f"searching S3 raised {error.__class__.__name__}, using local cache only." + ) + logger.debug(error) + versions = set() + versions = versions.union(get_algorithm_subdirectories_in_cache(prefix)) + return versions + + +def get_configuration_class_with_attributes( + klass: Type[AlgorithmConfiguration], +) -> Type[AlgorithmConfiguration]: + """Get AlgorithmConfiguration with set attributes. + + Args: + klass: a class to be used to extract attributes from. + + Returns: + a class with the attributes set. + """ + configuration_class = deepcopy(AlgorithmConfiguration) + setattr(configuration_class, "algorithm_type", klass.algorithm_type) + setattr(configuration_class, "algorithm_name", klass.algorithm_name) + setattr(configuration_class, "algorithm_application", klass.__name__) + setattr(configuration_class, "algorithm_version", klass.algorithm_version) + return configuration_class + + +class PropertyPredictor(ABC, Generic[S, U]): + """WIP""" + + def __init__(self, context: U) -> None: + """Property predictor to investigate items. + + Args: + context: the context in which a property of an item can be + computed or checked is very application specific. + """ + self.context = context + # or pass these with methods? + + @abstractmethod + def satisfies(self, item: S) -> bool: + """Check whether an item satisfies given requirements. + + Args: + item: the item to check. + + Returns: + bool: + """ + + def compute(self, item: S) -> Any: + """Compute some metric/property on an item. + + Args: + item: the item to compute a metric on. + + Returns: + Any: the computed metric/property. 
+ """ diff --git a/src/gt4sd/algorithms/generation/__init__.py b/src/gt4sd/algorithms/generation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/gt4sd/algorithms/generation/hugging_face/__init__.py b/src/gt4sd/algorithms/generation/hugging_face/__init__.py new file mode 100644 index 000000000..5a0e30554 --- /dev/null +++ b/src/gt4sd/algorithms/generation/hugging_face/__init__.py @@ -0,0 +1,21 @@ +"""HuggingFaceGenerationAlgorithm initialization.""" + +from .core import ( + HuggingFaceCTRLGenerator, + HuggingFaceGenerationAlgorithm, + HuggingFaceGPT2Generator, + HuggingFaceOpenAIGPTGenerator, + HuggingFaceTransfoXLGenerator, + HuggingFaceXLMGenerator, + HuggingFaceXLNetGenerator, +) + +__all__ = [ + "HuggingFaceGenerationAlgorithm", + "HuggingFaceXLMGenerator", + "HuggingFaceCTRLGenerator", + "HuggingFaceGPT2Generator", + "HuggingFaceOpenAIGPTGenerator", + "HuggingFaceXLNetGenerator", + "HuggingFaceTransfoXLGenerator", +] diff --git a/src/gt4sd/algorithms/generation/hugging_face/core.py b/src/gt4sd/algorithms/generation/hugging_face/core.py new file mode 100644 index 000000000..648301d73 --- /dev/null +++ b/src/gt4sd/algorithms/generation/hugging_face/core.py @@ -0,0 +1,314 @@ +"""HuggingFace generation algorithm.""" + +import logging +from dataclasses import field +from typing import ClassVar, Dict, Optional, Set, TypeVar + +from ...core import ( + AlgorithmConfiguration, + GeneratorAlgorithm, + Untargeted, + get_configuration_class_with_attributes, +) +from ...registry import ApplicationsRegistry +from .implementation import MODEL_TYPES, Generator + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + +T = type(None) +S = TypeVar("S", bound=str) + + +class HuggingFaceGenerationAlgorithm(GeneratorAlgorithm[S, T]): + def __init__( + self, configuration: AlgorithmConfiguration, target: Optional[T] = None + ): + """HuggingFace generation algorithm. + + Args: + configuration: domain and application + specification, defining types and validations. + target: unused since it is not a conditional generator. + + Example: + An example for using a generative algorithm from HuggingFace:: + + configuration = HuggingFaceXLMGenerator() + algorithm = HuggingFaceGenerationAlgorithm(configuration=configuration) + items = list(algorithm.sample(1)) + print(items) + """ + + configuration = self.validate_configuration(configuration) + # TODO there might also be a validation/check on the target input + + super().__init__( + configuration=configuration, + target=None, # type:ignore + ) + + def get_generator( + self, + configuration: AlgorithmConfiguration[S, T], + target: Optional[T], + ) -> Untargeted: + """Get the function to sample batches. + + Args: + configuration: helps to set up the application. + target: context or condition for the generation. Unused in the algorithm. + + Returns: + callable generating a batch of items. 
+ """ + logger.info("ensure artifacts for the application are present.") + self.local_artifacts = configuration.ensure_artifacts() + implementation: Generator = configuration.get_conditional_generator( # type: ignore + self.local_artifacts + ) + return implementation.sample + + def validate_configuration( + self, configuration: AlgorithmConfiguration + ) -> AlgorithmConfiguration: + # TODO raise InvalidAlgorithmConfiguration + assert isinstance(configuration, AlgorithmConfiguration) + return configuration + + +@ApplicationsRegistry.register_algorithm_application(HuggingFaceGenerationAlgorithm) +class HuggingFaceConfiguration(AlgorithmConfiguration[str, None]): + """Basic configuration for an hugging face algorithm.""" + + algorithm_type: ClassVar[str] = "generation" + domain: ClassVar[str] = "nlp" + + model_type: str = field( + default="", + metadata=dict( + description=f"Type of the model. Supported: {', '.join(MODEL_TYPES.keys())}" + ), + ) + prompt: str = field( + default="I'm a stochastic parrot.", + metadata=dict(description="Prompt for text generation."), + ) + length: int = field( + default=20, metadata=dict(description="Length of the generated text.") + ) + stop_token: str = field( + default="", metadata=dict(description="Stop token for text generation.") + ) + temperature: float = field( + default=1.0, + metadata=dict( + description="Temperature for sampling, the lower the greedier the sampling." + ), + ) + repetition_penalty: float = field( + default=1.0, + metadata=dict( + description="Primarily useful for CTRL model, where 1.2 should be used." + ), + ) + k: int = field( + default=50, + metadata=dict(description="Number of top-k probability tokens to keep."), + ) + p: float = field( + default=1.0, + metadata=dict( + description="Only tokens with cumulative probabilities summing up to this value are kept." + ), + ) + prefix: str = field( + default="", + metadata=dict( + description="Text defining context provided prior to the prompt." + ), + ) + number_of_sequences: int = field( + default=8, + metadata=dict(description="Number of text sequences to generate."), + ) + + def get_target_description(self) -> Optional[Dict[str, str]]: + """Get description of the target for generation. + + Returns: + target description, returns None in case no target is used. + """ + return None + + def get_conditional_generator(self, resources_path: str, **kwargs) -> Generator: + return Generator( + resources_path=resources_path, + model_type=self.model_type, + model_name=self.algorithm_version, + prompt=self.prompt, + length=self.length, + stop_token=self.stop_token, + temperature=self.temperature, + repetition_penalty=self.repetition_penalty, + k=self.k, + p=self.p, + prefix=self.prefix, + number_of_sequences=self.number_of_sequences, + ) + + +@ApplicationsRegistry.register_algorithm_application(HuggingFaceGenerationAlgorithm) +class HuggingFaceXLMGenerator(HuggingFaceConfiguration): + """Configuration to generate text using XLM.""" + + algorithm_version: str = "xlm-mlm-en-2048" + model_type: str = "xlm" + + @classmethod + def list_versions(cls) -> Set[str]: + """Get possible algorithm versions. + + Standard S3 and cache search adding the version used in the configuration. + + Returns: + viable values as :attr:`algorithm_version` for the environment. 
+ """ + logger.warning( + "more algorithm versions can be found on https://huggingface.co/models" + ) + return ( + get_configuration_class_with_attributes(cls) + .list_versions() + .union({cls.algorithm_version}) + ) + + +@ApplicationsRegistry.register_algorithm_application(HuggingFaceGenerationAlgorithm) +class HuggingFaceCTRLGenerator(HuggingFaceConfiguration): + """Configuration to generate text using CTRL.""" + + algorithm_version: str = "ctrl" + model_type: str = "ctrl" + + @classmethod + def list_versions(cls) -> Set[str]: + """Get possible algorithm versions. + + Standard S3 and cache search adding the version used in the configuration. + + Returns: + viable values as :attr:`algorithm_version` for the environment. + """ + logger.warning( + "more algorithm versions can be found on https://huggingface.co/models" + ) + return ( + get_configuration_class_with_attributes(cls) + .list_versions() + .union({cls.algorithm_version}) + ) + + +@ApplicationsRegistry.register_algorithm_application(HuggingFaceGenerationAlgorithm) +class HuggingFaceGPT2Generator(HuggingFaceConfiguration): + """Configuration to generate text using GPT2.""" + + algorithm_version: str = "gpt2" + model_type: str = "gpt2" + + @classmethod + def list_versions(cls) -> Set[str]: + """Get possible algorithm versions. + + Standard S3 and cache search adding the version used in the configuration. + + Returns: + viable values as :attr:`algorithm_version` for the environment. + """ + logger.warning( + "more algorithm versions can be found on https://huggingface.co/models" + ) + return ( + get_configuration_class_with_attributes(cls) + .list_versions() + .union({cls.algorithm_version}) + ) + + +@ApplicationsRegistry.register_algorithm_application(HuggingFaceGenerationAlgorithm) +class HuggingFaceOpenAIGPTGenerator(HuggingFaceConfiguration): + """Configuration to generate text using OpenAIGPT.""" + + algorithm_version: str = "openai-gpt" + model_type: str = "openai-gpt" + + @classmethod + def list_versions(cls) -> Set[str]: + """Get possible algorithm versions. + + Standard S3 and cache search adding the version used in the configuration. + + Returns: + viable values as :attr:`algorithm_version` for the environment. + """ + logger.warning( + "more algorithm versions can be found on https://huggingface.co/models" + ) + return ( + get_configuration_class_with_attributes(cls) + .list_versions() + .union({cls.algorithm_version}) + ) + + +@ApplicationsRegistry.register_algorithm_application(HuggingFaceGenerationAlgorithm) +class HuggingFaceXLNetGenerator(HuggingFaceConfiguration): + """Configuration to generate text using XLNet.""" + + algorithm_version: str = "xlnet-large-cased" + model_type: str = "xlnet" + + @classmethod + def list_versions(cls) -> Set[str]: + """Get possible algorithm versions. + + Standard S3 and cache search adding the version used in the configuration. + + Returns: + viable values as :attr:`algorithm_version` for the environment. 
+ """ + logger.warning( + "more algorithm versions can be found on https://huggingface.co/models" + ) + return ( + get_configuration_class_with_attributes(cls) + .list_versions() + .union({cls.algorithm_version}) + ) + + +@ApplicationsRegistry.register_algorithm_application(HuggingFaceGenerationAlgorithm) +class HuggingFaceTransfoXLGenerator(HuggingFaceConfiguration): + """Configuration to generate text using TransfoXL.""" + + algorithm_version: str = "transfo-xl-wt103" + model_type: str = "transfo-xl" + + @classmethod + def list_versions(cls) -> Set[str]: + """Get possible algorithm versions. + + Standard S3 and cache search adding the version used in the configuration. + + Returns: + viable values as :attr:`algorithm_version` for the environment. + """ + logger.warning( + "more algorithm versions can be found on https://huggingface.co/models" + ) + return ( + get_configuration_class_with_attributes(cls) + .list_versions() + .union({cls.algorithm_version}) + ) diff --git a/src/gt4sd/algorithms/generation/hugging_face/implementation.py b/src/gt4sd/algorithms/generation/hugging_face/implementation.py new file mode 100644 index 000000000..1455c5426 --- /dev/null +++ b/src/gt4sd/algorithms/generation/hugging_face/implementation.py @@ -0,0 +1,270 @@ +""" +Implementation details for HuggingFace generation algorithms. + +Parts of the implementation inspired by: https://github.com/huggingface/transformers/blob/v4.2.1/examples/text-generation/run_generation.py. +""" + +import logging +import os +from typing import List, Optional, Union + +import numpy as np +import torch +from transformers import ( + BasicTokenizer, + CTRLLMHeadModel, + CTRLTokenizer, + GPT2LMHeadModel, + GPT2Tokenizer, + OpenAIGPTLMHeadModel, + OpenAIGPTTokenizer, + TransfoXLLMHeadModel, + TransfoXLTokenizer, + XLMTokenizer, + XLMWithLMHeadModel, + XLNetLMHeadModel, + XLNetTokenizer, +) + +from ....frameworks.torch import device_claim + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + +MAXIMUM_LENGTH = int(10000) +# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia +# in https://github.com/rusiaaman/XLNet-gen#methodology +# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e +PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family +(except for Alexei and Maria) are discovered. +The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the +remainder of the story. 1883 Western Siberia, +a young Grigori Rasputin is asked by his father and a group of men to perform magic. +Rasputin has a vision and denounces one of the men as a horse thief. Although his +father initially slaps him for making such an accusation, Rasputin watches as the +man is chased outside and beaten. Twenty years later, Rasputin sees a vision of +the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, +with people, even a bishop, begging for his blessing. """ + + +def set_seed(seed: int = 42) -> None: + """Set seed for all random number generators. + + Args: + seed: seed to set. Defaults to 42. + """ + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available: + torch.cuda.manual_seed_all(seed) # type:ignore + + +def prepare_ctrl_input(tokenizer: BasicTokenizer, prompt: str, **kwargs): + if kwargs.get("temperature", 1.0) > 0.7: + logger.warning( + "CTRL typically works better with lower temperatures (and lower k)." 
+ ) + + encoded_prompt = tokenizer.encode(prompt, add_special_tokens=False) # type:ignore + if not any( + encoded_prompt[0] == x for x in tokenizer.control_codes.values() # type:ignore + ): + logger.warning( + "not starting generation from a control code so you will not get good results" + ) + return prompt + + +def prepare_prefix_input(tokenizer: BasicTokenizer, prompt: str, **kwargs): + prefix = kwargs["prefix"] if kwargs.get("prefix", "") else PREFIX + prompt = prefix + prompt + return prompt + + +def adjust_length_to_model(length: int, maximum_sequence_length: int): + """Adjust sequence length. + + Args: + length: target length. + maximum_sequence_length: maximum sequence length. + + Returns: + the adjusted length. + """ + if length < 0 and maximum_sequence_length > 0: + logger.warning( + f"negative length, adjusting to model supported length {maximum_sequence_length}" + ) + length = maximum_sequence_length + elif 0 < maximum_sequence_length < length: + logger.warning( + f"longer then model supported length, adjusting to {maximum_sequence_length}" + ) + length = maximum_sequence_length + elif length < 0: + logger.warning(f"negative length, adjusting to maximal length {MAXIMUM_LENGTH}") + length = MAXIMUM_LENGTH + return length + + +MODEL_TYPES = { + "gpt2": (GPT2LMHeadModel, GPT2Tokenizer, None), + "ctrl": (CTRLLMHeadModel, CTRLTokenizer, prepare_ctrl_input), + "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, None), + "xlnet": (XLNetLMHeadModel, XLNetTokenizer, prepare_prefix_input), + "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer, prepare_prefix_input), + "xlm": (XLMWithLMHeadModel, XLMTokenizer, None), +} + + +class Generator: + """Implementation of a generator.""" + + def __init__( + self, + resources_path: str, + model_type: str, + model_name: str, + prompt: str, + length: int, + stop_token: str, + temperature: float, + repetition_penalty: float, + k: int, + p: float, + prefix: str, + number_of_sequences: int, + device: Optional[Union[torch.device, str]] = None, + ): + """An HuggingFace generation algorithm. + + Args: + resources_path: path to the cache. + model_type: type of the model. + model_name: name of the model weights/version. + prompt: prompt for text generation. + length: length of the generated text. + stop_token: stop token for text generation. + temperature: temperature for sampling, the lower the greedier the sampling. + repetition_penalty: primarily useful for CTRL model, where 1.2 should be used. + k: number of top-k probability token to keep. + p: only tokens with cumulative probabilities summing up to this value are kept. + prefix: text defining context provided prior to the prompt. + number_of_sequences: number of generated sequences. + device: device where the inference + is running either as a dedicated class or a string. If not provided is inferred. 
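+
+        Example:
+            A construction sketch (the resources path is hypothetical; when it
+            is missing or empty, :meth:`load_model` falls back to downloading
+            ``model_name`` from the hub)::
+
+                generator = Generator(
+                    resources_path="/tmp/gpt2",
+                    model_type="gpt2",
+                    model_name="gpt2",
+                    prompt="Hello world",
+                    length=20,
+                    stop_token="",
+                    temperature=1.0,
+                    repetition_penalty=1.0,
+                    k=50,
+                    p=1.0,
+                    prefix="",
+                    number_of_sequences=1,
+                )
+                print(generator.sample())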
+ """ + self.device = device_claim(device) + self.resources_path = resources_path + self.model_type = model_type + self.model_name = model_name + self.prompt = prompt + self.length = length + self.stop_token = None if stop_token == "" else stop_token + self.temperature = temperature + self.repetition_penalty = repetition_penalty + self.k = k + self.p = p + self.prefix = prefix + self.number_of_sequences = number_of_sequences + self.load_model() + + def load_model(self) -> None: + """Load a pretrained HuggingFace generation model.""" + try: + model_class, tokenizer_class, preprocessing_function = MODEL_TYPES[ + self.model_type + ] + except KeyError: + raise KeyError(f"model type: {self.model_type} not supported") + if ( + os.path.exists(self.resources_path) + and len(os.listdir(self.resources_path)) > 0 + ): + model_name_or_path = self.resources_path + else: + model_name_or_path = self.model_name + self.preprocessing_function = preprocessing_function + self.tokenizer = tokenizer_class.from_pretrained( # type:ignore + model_name_or_path + ) + self.model = model_class.from_pretrained(model_name_or_path) + self.model.to(self.device) + # adjusting length + self.length = adjust_length_to_model( + self.length, self.model.config.max_position_embeddings + ) + + def sample(self) -> List[str]: + """Sample text snippets. + + Returns: + generated text snippets. + """ + if self.preprocessing_function is not None: + preprocessed_prompt_text = self.preprocessing_function( + self.tokenizer, + self.prompt, + prefix=self.prefix, + temperature=self.temperature, + ) + + if self.model.__class__.__name__ in ["TransfoXLLMHeadModel"]: + tokenizer_kwargs = {"add_space_before_punct_symbol": True} + else: + tokenizer_kwargs = {} + + encoded_prompt = self.tokenizer.encode( + preprocessed_prompt_text, + add_special_tokens=False, + return_tensors="pt", + **tokenizer_kwargs, + ) + else: + encoded_prompt = self.tokenizer.encode( + self.prefix + self.prompt, add_special_tokens=False, return_tensors="pt" + ) + + encoded_prompt = encoded_prompt.to(self.device) + + if encoded_prompt.size()[-1] == 0: + input_ids = None + else: + input_ids = encoded_prompt + + output_sequences = self.model.generate( + input_ids=input_ids, + max_length=self.length + len(encoded_prompt[0]), + temperature=self.temperature, + top_k=self.k, + top_p=self.p, + repetition_penalty=self.repetition_penalty, + do_sample=True, + num_return_sequences=self.number_of_sequences, + ) + + # NOTE: remove the batch dimension when returning multiple sequences + if len(output_sequences.shape) > 2: + output_sequences.squeeze_() + + generated_sequences: List[str] = [] + + for generated_sequence in output_sequences: + generated_sequence = generated_sequence.tolist() + text = self.tokenizer.decode( + generated_sequence, clean_up_tokenization_spaces=True + ) + text = text[: text.find(self.stop_token) if self.stop_token else None] + total_sequence = ( + self.prompt + + text[ + len( + self.tokenizer.decode( + encoded_prompt[0], clean_up_tokenization_spaces=True + ) + ) : + ] + ) + generated_sequences.append(total_sequence) + + return generated_sequences diff --git a/src/gt4sd/algorithms/generation/molgx/__init__.py b/src/gt4sd/algorithms/generation/molgx/__init__.py new file mode 100644 index 000000000..89c993f3e --- /dev/null +++ b/src/gt4sd/algorithms/generation/molgx/__init__.py @@ -0,0 +1,17 @@ +"""MolGX initialization.""" +import logging + +from ....extras import EXTRAS_ENABLED + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + +if 
+    from .core import MolGX, MolGXQM9Generator
+
+    __all__ = [
+        "MolGX",
+        "MolGXQM9Generator",
+    ]
+else:
+    logger.warning("install AMD_Analytics extras to use MolGX")
diff --git a/src/gt4sd/algorithms/generation/molgx/core.py b/src/gt4sd/algorithms/generation/molgx/core.py
new file mode 100644
index 000000000..89c8a0886
--- /dev/null
+++ b/src/gt4sd/algorithms/generation/molgx/core.py
@@ -0,0 +1,227 @@
+"""MolGX Algorithm.
+
+MolGX generation algorithm.
+"""
+
+import logging
+from dataclasses import field
+from typing import Any, ClassVar, Dict, Iterator, Optional, TypeVar
+
+from ....extras import EXTRAS_ENABLED
+
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler())
+
+if EXTRAS_ENABLED:
+    from ....domains.materials import SmallMolecule, validate_molecules
+    from ....exceptions import InvalidItem
+    from ...core import AlgorithmConfiguration, GeneratorAlgorithm, Untargeted
+    from ...registry import ApplicationsRegistry
+    from .implementation import MolGXGenerator
+
+    T = type(None)
+    S = TypeVar("S", bound=SmallMolecule)
+
+    class MolGX(GeneratorAlgorithm[S, T]):
+        """MolGX Algorithm."""
+
+        def __init__(
+            self,
+            configuration: AlgorithmConfiguration[S, T],
+            target: Optional[T] = None,
+        ):
+            """Instantiate MolGX ready to generate items.
+
+            Args:
+                configuration: domain and application
+                    specification defining parameters, types and validations.
+                target: a target for which to generate items.
+
+            Example:
+                An example for generating small molecules (SMILES) with given
+                HOMO and LUMO energies (carried by the configuration)::
+
+                    configuration = MolGXQM9Generator()
+                    molgx = MolGX(configuration=configuration)
+                    items = list(molgx.sample(10))
+                    print(items)
+            """
+
+            configuration = self.validate_configuration(configuration)
+            # TODO there might also be a validation/check on the target input
+
+            super().__init__(
+                configuration=configuration,  # type:ignore
+                target=target,  # type:ignore
+            )
+
+        def get_generator(
+            self,
+            configuration: AlgorithmConfiguration[S, T],
+            target: Optional[T],
+        ) -> Untargeted:
+            """Get the function to sample batches via the ConditionalGenerator.
+
+            Args:
+                configuration: helps to set up the application.
+                target: context or condition for the generation. Unused in the algorithm.
+
+            Returns:
+                callable generating a batch of items.
+            """
+            logger.info("ensure artifacts for the application are present.")
+            self.local_artifacts = configuration.ensure_artifacts()
+            implementation: MolGXGenerator = configuration.get_conditional_generator(  # type: ignore
+                self.local_artifacts
+            )
+            return implementation.generate
+
+        def validate_configuration(
+            self, configuration: AlgorithmConfiguration[S, T]
+        ) -> AlgorithmConfiguration[S, T]:
+            # TODO raise InvalidAlgorithmConfiguration
+            assert isinstance(configuration, AlgorithmConfiguration)
+            return configuration
+
+        def sample(self, number_of_items: int = 100) -> Iterator[S]:
+            """Generate a number of unique and valid items.
+
+            Args:
+                number_of_items: number of items to generate.
+                    Defaults to 100.
+
+            Yields:
+                the items.
+            """
+            if hasattr(self.configuration, "maximum_number_of_solutions"):
+                maximum_number_of_molecules = int(
+                    getattr(self.configuration, "maximum_number_of_solutions")
+                )
+                if number_of_items > maximum_number_of_molecules:
+                    logger.warning(
+                        f"current MolGX configuration cannot support generation of {number_of_items} molecules..."
+                    )
+                    logger.warning(
+                        f"to enable generation of {number_of_items} molecules, increase 'maximum_number_of_solutions' (currently set to {maximum_number_of_molecules})"
+                    )
+                    number_of_items = maximum_number_of_molecules
+                    logger.warning(
+                        f"generating at most: {maximum_number_of_molecules}..."
+                    )
+            return super().sample(number_of_items=number_of_items)
+
+    @ApplicationsRegistry.register_algorithm_application(MolGX)
+    class MolGXQM9Generator(AlgorithmConfiguration[SmallMolecule, Any]):
+        """Configuration to generate compounds with given HOMO and LUMO energies."""
+
+        algorithm_type: ClassVar[str] = "generation"
+        domain: ClassVar[str] = "materials"
+        algorithm_version: str = "v0"
+
+        homo_energy_value: float = field(
+            default=-0.25,
+            metadata=dict(description="Target HOMO energy value."),
+        )
+        lumo_energy_value: float = field(
+            default=0.08,
+            metadata=dict(description="Target LUMO energy value."),
+        )
+        use_linear_model: bool = field(
+            default=True,
+            metadata=dict(description="Linear model usage."),
+        )
+        number_of_candidates: int = field(
+            default=2,
+            metadata=dict(description="Number of candidates to consider."),
+        )
+        maximum_number_of_candidates: int = field(
+            default=5,
+            metadata=dict(description="Maximum number of candidates to consider."),
+        )
+        maximum_number_of_solutions: int = field(
+            default=10,
+            metadata=dict(description="Maximum number of solutions."),
+        )
+        maximum_number_of_nodes: int = field(
+            default=50000,
+            metadata=dict(
+                description="Maximum number of nodes in the graph exploration."
+            ),
+        )
+        beam_size: int = field(
+            default=2000,
+            metadata=dict(description="Size of the beam during search."),
+        )
+        without_estimate: bool = field(
+            default=True,
+            metadata=dict(description="Disable estimates."),
+        )
+        use_specific_rings: bool = field(
+            default=True,
+            metadata=dict(
+                description="Flag to indicate whether specific rings are used."
+            ),
+        )
+        use_fragment_const: bool = field(
+            default=False,
+            metadata=dict(description="Using constant fragments."),
+        )
+
+        def get_target_description(self) -> Optional[Dict[str, str]]:
+            """Get description of the target for generation.
+
+            Returns:
+                target description, returns None in case no target is used.
+            """
+            return None
+
+        def get_conditional_generator(self, resources_path: str) -> MolGXGenerator:
+            """Instantiate the actual generator implementation.
+
+            Args:
+                resources_path: local path to model files.
+
+            Returns:
+                instance with :meth:`generate` for generation.
+            """
+            return MolGXGenerator(
+                resources_path=resources_path,
+                homo_energy_value=self.homo_energy_value,
+                lumo_energy_value=self.lumo_energy_value,
+                use_linear_model=self.use_linear_model,
+                number_of_candidates=self.number_of_candidates,
+                maximum_number_of_candidates=self.maximum_number_of_candidates,
+                maximum_number_of_solutions=self.maximum_number_of_solutions,
+                maximum_number_of_nodes=self.maximum_number_of_nodes,
+                beam_size=self.beam_size,
+                without_estimate=self.without_estimate,
+                use_specific_rings=self.use_specific_rings,
+                use_fragment_const=self.use_fragment_const,
+                tag_name="qm9",
+            )
+
+        def validate_item(self, item: str) -> SmallMolecule:
+            """Check that item is a valid SMILES.
+
+            Args:
+                item: a generated item that is possibly not valid.
+
+            Raises:
+                InvalidItem: in case the item cannot be validated.
+
+            Returns:
+                the validated SMILES.
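+
+            Example:
+                A sketch of the expected behavior (a parsable SMILES passes
+                through, anything RDKit rejects raises :class:`InvalidItem`)::
+
+                    configuration = MolGXQM9Generator()
+                    molecule = configuration.validate_item("CCO")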
+ """ + ( + molecules, + _, + ) = validate_molecules([item]) + if molecules[0] is None: + raise InvalidItem( + title="InvalidSMILES", + detail=f'rdkit.Chem.MolFromSmiles returned None for "{item}"', + ) + return SmallMolecule(item) + + +else: + logger.warning("install AMD_analytcs extras to use MolGX") diff --git a/src/gt4sd/algorithms/generation/molgx/implementation.py b/src/gt4sd/algorithms/generation/molgx/implementation.py new file mode 100644 index 000000000..a4d62f887 --- /dev/null +++ b/src/gt4sd/algorithms/generation/molgx/implementation.py @@ -0,0 +1,239 @@ +"""Implementation of MolGX conditional generators.""" + +import logging +import os +from typing import Any, Dict, List + +from ....extras import EXTRAS_ENABLED + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + +if EXTRAS_ENABLED: + from AMD_Analytics.amdsdk import AMDsdk + + class MolGXGenerator: + """Interface for MolGX generator.""" + + def __init__( + self, + resources_path: str, + tag_name: str, + homo_energy_value: float = -0.25, + lumo_energy_value: float = 0.08, + use_linear_model: bool = True, + number_of_candidates: int = 2, + maximum_number_of_candidates: int = 3, + maximum_number_of_solutions: int = 3, + maximum_number_of_nodes: int = 50000, + beam_size: int = 2000, + without_estimate: bool = True, + use_specific_rings: bool = True, + use_fragment_const: bool = False, + ) -> None: + """Instantiate a MolGX generator. + + Args: + resources_path: path to the resources for model loading. + tag_name: tag for the pretrained model. + homo_energy_value: target HOMO energy value. Defaults to -0.25. + lumo_energy_value: target LUMO energy value. Defaults to 0.08. + use_linear_model: linear model usage. Defaults to True. + number_of_candidates: number of candidates to consider. Defaults to 2. + maximum_number_of_candidates: maximum number of candidates to consider. Defaults to 3. + maximum_number_of_solutions: maximum number of solutions. Defaults to 3. + maximum_number_of_nodes: maximum number of nodes in the graph exploration. Defaults to 50000. + beam_size: size of the beam during search. Defaults to 2000. + without_estimate: disable estimates. Defaults to True. + use_specific_rings: flag to indicate whether specific rings are used. Defaults to True. + use_fragment_const: using constant fragments. Defaults to False. + + Raises: + RuntimeError: in the case extras are disabled. + """ + if not EXTRAS_ENABLED: + raise RuntimeError("Can't instantiate MolGXGenerator, extras disabled!") + + # loading artifacts + self.resources_path = resources_path + self.tag_name = tag_name + self.amd = self.load_molgx(self.resources_path, self.tag_name) + self.molecules_data, self.target_property = self.amd.LoadPickle("model") + # algorithm parameters + self._homo_energy_value = homo_energy_value + self._lumo_energy_value = lumo_energy_value + self._use_linear_model = use_linear_model + self._number_of_candidates = number_of_candidates + self._maximum_number_of_candidates = maximum_number_of_candidates + self._maximum_number_of_solutions = maximum_number_of_solutions + self._maximum_number_of_nodes = maximum_number_of_nodes + self._beam_size = beam_size + self._without_estimate = without_estimate + self._use_specific_rings = use_specific_rings + self._use_fragment_const = use_fragment_const + self._parameters = self._create_parameters_dictionary() + + @staticmethod + def load_molgx(resource_path: str, tag_name: str) -> AMDsdk: + """Load MolGX model. 
+
+            Args:
+                resource_path: path to the resources for model loading.
+                tag_name: tag for the pretrained model.
+
+            Returns:
+                MolGX model SDK.
+            """
+            return AMDsdk(
+                dir_pickle=os.path.join(resource_path, "pickle"),
+                dir_data=os.path.join(resource_path, "data"),
+                tag_data=tag_name,
+            )
+
+        def _create_parameters_dictionary(self) -> Dict[str, Any]:
+            """Create parameters dictionary.
+
+            Returns:
+                the parameters to run MolGX.
+            """
+            self.target_property["homo"] = (self.homo_energy_value,) * 2
+            self.target_property["lumo"] = (self.lumo_energy_value,) * 2
+            parameters: Dict[str, Any] = {}
+            parameters["target_property"] = self.target_property
+            parameters["use_linear_model"] = self.use_linear_model
+            parameters["num_candidate"] = self.number_of_candidates
+            parameters["max_candidate"] = self.maximum_number_of_candidates
+            parameters["max_solution"] = self.maximum_number_of_solutions
+            parameters["max_node"] = self.maximum_number_of_nodes
+            parameters["beam_size"] = self.beam_size
+            parameters["without_estimate"] = self.without_estimate
+            parameters["use_specific_rings"] = self.use_specific_rings
+            parameters["use_fragment_const"] = self.use_fragment_const
+            return parameters
+
+        @property
+        def homo_energy_value(self) -> float:
+            return self._homo_energy_value
+
+        @homo_energy_value.setter
+        def homo_energy_value(self, value: float) -> None:
+            self._homo_energy_value = value
+            self.parameters = self._create_parameters_dictionary()
+
+        @property
+        def lumo_energy_value(self) -> float:
+            return self._lumo_energy_value
+
+        @lumo_energy_value.setter
+        def lumo_energy_value(self, value: float) -> None:
+            self._lumo_energy_value = value
+            self.parameters = self._create_parameters_dictionary()
+
+        @property
+        def use_linear_model(self) -> bool:
+            return self._use_linear_model
+
+        @use_linear_model.setter
+        def use_linear_model(self, value: bool) -> None:
+            self._use_linear_model = value
+            self.parameters = self._create_parameters_dictionary()
+
+        @property
+        def number_of_candidates(self) -> int:
+            return self._number_of_candidates
+
+        @number_of_candidates.setter
+        def number_of_candidates(self, value: int) -> None:
+            self._number_of_candidates = value
+            self.parameters = self._create_parameters_dictionary()
+
+        @property
+        def maximum_number_of_candidates(self) -> int:
+            return self._maximum_number_of_candidates
+
+        @maximum_number_of_candidates.setter
+        def maximum_number_of_candidates(self, value: int) -> None:
+            self._maximum_number_of_candidates = value
+            self.parameters = self._create_parameters_dictionary()
+
+        @property
+        def maximum_number_of_solutions(self) -> int:
+            return self._maximum_number_of_solutions
+
+        @maximum_number_of_solutions.setter
+        def maximum_number_of_solutions(self, value: int) -> None:
+            self._maximum_number_of_solutions = value
+            self.parameters = self._create_parameters_dictionary()
+
+        @property
+        def maximum_number_of_nodes(self) -> int:
+            return self._maximum_number_of_nodes
+
+        @maximum_number_of_nodes.setter
+        def maximum_number_of_nodes(self, value: int) -> None:
+            self._maximum_number_of_nodes = value
+            self.parameters = self._create_parameters_dictionary()
+
+        @property
+        def beam_size(self) -> int:
+            return self._beam_size
+
+        @beam_size.setter
+        def beam_size(self, value: int) -> None:
+            self._beam_size = value
+            self.parameters = self._create_parameters_dictionary()
+
+        @property
+        def without_estimate(self) -> bool:
+            return self._without_estimate
+
+        @without_estimate.setter
+        def without_estimate(self, value: bool) -> None:
+            self._without_estimate = value
+            self.parameters = self._create_parameters_dictionary()
+
+        @property
+        def use_specific_rings(self) -> bool:
+            return self._use_specific_rings
+
+        @use_specific_rings.setter
+        def use_specific_rings(self, value: bool) -> None:
+            self._use_specific_rings = value
+            self.parameters = self._create_parameters_dictionary()
+
+        @property
+        def use_fragment_const(self) -> bool:
+            return self._use_fragment_const
+
+        @use_fragment_const.setter
+        def use_fragment_const(self, value: bool) -> None:
+            self._use_fragment_const = value
+            self.parameters = self._create_parameters_dictionary()
+
+        @property
+        def parameters(self) -> Dict[str, Any]:
+            return self._parameters
+
+        @parameters.setter
+        def parameters(self, value: Dict[str, Any]) -> None:
+            parameters = self._create_parameters_dictionary()
+            parameters.update(value)
+            self._parameters = parameters
+
+        def generate(self) -> List[str]:
+            """Sample random molecules.
+
+            Returns:
+                sampled molecules (SMILES).
+            """
+            # generate molecules
+            logger.info(
+                f"running MolGX with the following parameters: {self.parameters}"
+            )
+            molecules_df = self.amd.GenMols(self.molecules_data, self.parameters)
+            logger.info("MolGX run completed")
+            return molecules_df["SMILES"].tolist()
+
+
+else:
+    logger.warning("install AMD_Analytics extras to use MolGX")
diff --git a/src/gt4sd/algorithms/generation/polymer_blocks/__init__.py b/src/gt4sd/algorithms/generation/polymer_blocks/__init__.py
new file mode 100644
index 000000000..862c08d1d
--- /dev/null
+++ b/src/gt4sd/algorithms/generation/polymer_blocks/__init__.py
@@ -0,0 +1,5 @@
+"""PolymerBlocks initialization."""
+
+from .core import PolymerBlocks, PolymerBlocksGenerator
+
+__all__ = ["PolymerBlocks", "PolymerBlocksGenerator"]
diff --git a/src/gt4sd/algorithms/generation/polymer_blocks/core.py b/src/gt4sd/algorithms/generation/polymer_blocks/core.py
new file mode 100644
index 000000000..5ad1f8c00
--- /dev/null
+++ b/src/gt4sd/algorithms/generation/polymer_blocks/core.py
@@ -0,0 +1,122 @@
+"""PaccMann vanilla generator trained on polymer building blocks (catalysts/monomers)."""
+
+import logging
+from dataclasses import field
+from typing import ClassVar, Dict, Optional, TypeVar
+
+from ....domains.materials import SmallMolecule, validate_molecules
+from ....exceptions import InvalidItem
+from ...core import AlgorithmConfiguration, GeneratorAlgorithm, Untargeted
+from ...registry import ApplicationsRegistry
+from .implementation import Generator
+
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler())
+
+T = type(None)
+S = TypeVar("S", bound=SmallMolecule)
+
+
+class PolymerBlocks(GeneratorAlgorithm[S, T]):
+    def __init__(
+        self, configuration: AlgorithmConfiguration, target: Optional[T] = None
+    ):
+        """Polymer blocks generation.
+
+        Args:
+            configuration: domain and application
+                specification, defining types and validations.
+            target: unused since it is not a conditional generator.
+
+        Example:
+            An example for generating small molecules (SMILES) that resemble
+            monomers/catalysts for polymer synthesis::
+
+                configuration = PolymerBlocksGenerator()
+                polymer_blocks = PolymerBlocks(configuration=configuration)
+                items = list(polymer_blocks.sample(10))
+                print(items)
+        """
+
+        configuration = self.validate_configuration(configuration)
+        # TODO there might also be a validation/check on the target input
+
+        super().__init__(
+            configuration=configuration,
+            target=None,  # type:ignore
+        )
+
+    def get_generator(
+        self,
+        configuration: AlgorithmConfiguration[S, T],
+        target: Optional[T],
+    ) -> Untargeted:
+        """Get the function to sample batches via the Generator.
+
+        Args:
+            configuration: helps to set up the application.
+            target: context or condition for the generation. Unused in the algorithm.
+
+        Returns:
+            callable generating a batch of items.
+        """
+        logger.info("ensure artifacts for the application are present.")
+        self.local_artifacts = configuration.ensure_artifacts()
+        implementation: Generator = configuration.get_conditional_generator(  # type: ignore
+            self.local_artifacts
+        )
+        return implementation.sample
+
+    def validate_configuration(
+        self, configuration: AlgorithmConfiguration
+    ) -> AlgorithmConfiguration:
+        # TODO raise InvalidAlgorithmConfiguration
+        assert isinstance(configuration, AlgorithmConfiguration)
+        return configuration
+
+
+@ApplicationsRegistry.register_algorithm_application(PolymerBlocks)
+class PolymerBlocksGenerator(AlgorithmConfiguration[SmallMolecule, None]):
+    """Configuration to generate subunits of polymers."""
+
+    algorithm_type: ClassVar[str] = "generation"
+    domain: ClassVar[str] = "materials"
+    algorithm_version: str = "v0"
+
+    batch_size: int = field(
+        default=32,
+        metadata=dict(description="Batch size used for the generative model sampling."),
+    )
+    generated_length: int = field(
+        default=100,
+        metadata=dict(
+            description="Maximum length in tokens of the generated molecules (relates to the SMILES length)."
+        ),
+    )
+
+    def get_target_description(self) -> Optional[Dict[str, str]]:
+        """Get description of the target for generation.
+
+        Returns:
+            target description, returns None in case no target is used.
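+
+        Example:
+            For this untargeted algorithm the description is always empty::
+
+                assert PolymerBlocksGenerator().get_target_description() is None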
+ """ + return None + + def get_conditional_generator(self, resources_path: str) -> Generator: + return Generator( + resources_path=resources_path, + generated_length=self.generated_length, + batch_size=self.batch_size, + ) + + def validate_item(self, item: str) -> SmallMolecule: + ( + molecules, + _, + ) = validate_molecules([item]) + if molecules[0] is None: + raise InvalidItem( + title="InvalidSMILES", + detail=f'rdkit.Chem.MolFromSmiles returned None for "{item}"', + ) + return SmallMolecule(item) diff --git a/src/gt4sd/algorithms/generation/polymer_blocks/implementation.py b/src/gt4sd/algorithms/generation/polymer_blocks/implementation.py new file mode 100644 index 000000000..744d9ae8c --- /dev/null +++ b/src/gt4sd/algorithms/generation/polymer_blocks/implementation.py @@ -0,0 +1,126 @@ +"""Implementation details for PaccMann vanilla generator trained on polymer building blocks (catalysts/monomers).""" + +import json +import os +from typing import List, Optional, Union + +import torch +from rdkit import Chem, RDLogger +from paccmann_chemistry.models.vae import StackGRUDecoder, StackGRUEncoder, TeacherVAE +from paccmann_chemistry.utils import get_device +from paccmann_chemistry.utils.search import SamplingSearch +from pytoda.smiles.smiles_language import SMILESLanguage +from pytoda.smiles.transforms import Selfies, SMILESToTokenIndexes +from pytoda.transforms import Compose, ToTensor + +from ....frameworks.torch import device_claim + +RDLogger.DisableLog("rdApp.*") + + +class Generator: + def __init__( + self, + resources_path: str, + generated_length: int = 100, + batch_size: int = 32, + device: Optional[Union[torch.device, str]] = None, + ): + """Initialize the encoder/decoder generative model. + + Args: + resources_path: directory where to find models and parameters. + generated_length: length of the generated molecule in tokens. Defaults to 100. + batch_size: size of the batch. Defaults to 1. + device: device where the inference is running either as a dedicated class or a string. + If not provided is inferred. + """ + self.device = device_claim(device) + self.generated_length = generated_length + self.batch_size = batch_size + self.resources_path = resources_path + self.load_pretrained_paccmann( + os.path.join(self.resources_path, "params.json"), + os.path.join(self.resources_path, "smiles_language.pkl"), + os.path.join(self.resources_path, "weights.pt"), + self.batch_size, + ) + + def load_pretrained_paccmann( + self, params_file: str, lang_file: str, weights_file: str, batch_size: int + ) -> None: + """Load a pretrained PaccMann model. + + Args: + params_file: file for the parameters. + lang_file: language file. + weights_file: serialized weights file. + batch_size: size of the batch. 
+ """ + params = dict() + with open(params_file, "r") as f: + params.update(json.load(f)) + params["batch_mode"] = "Padded" + params["batch_size"] = batch_size + + self.selfies = params.get("selfies", False) + + self.device = get_device() + self.smiles_language = SMILESLanguage.load(lang_file) + + self.gru_encoder = StackGRUEncoder(params).to(self.device) + self.gru_decoder = StackGRUDecoder(params).to(self.device) + self.gru_vae = TeacherVAE(self.gru_encoder, self.gru_decoder).to(self.device) + self.gru_vae.load_state_dict(torch.load(weights_file, map_location=self.device)) + self.gru_vae.eval() + + transforms = [] + if self.selfies: + transforms += [Selfies()] + transforms += [SMILESToTokenIndexes(smiles_language=self.smiles_language)] + transforms += [ToTensor(device=self.device)] + self.transform = Compose(transforms) + + def decode( + self, latent_z: torch.Tensor, search: SamplingSearch = SamplingSearch() + ) -> List[int]: + """Decodes a sequence of tokens given a position in the latent space. + + Args: + latent_z: a batch size x latent size tensor. + search: defaults to sampling multinomial search. + + Returns: + list of list of token indices. + """ + latent_z = latent_z.view(1, latent_z.shape[0], latent_z.shape[1]).float() + molecule_iter = self.gru_vae.generate( + latent_z, + prime_input=torch.tensor([self.smiles_language.start_index]).to( + self.device + ), + end_token=torch.tensor([self.smiles_language.stop_index]).to(self.device), + generate_len=self.generated_length, + search=search, + ) + return [ + [self.smiles_language.start_index] + m.cpu().detach().tolist() + for m in molecule_iter + ] + + def sample(self) -> List[str]: + """Sample random molecules. + + Returns: + sampled molecule (SMILES). + """ + mol: List[str] = [] + while len(mol) < 1: + indexes = self.decode( + torch.randn( + self.batch_size, self.gru_decoder.latent_dim, device=self.device + ) + ) + mol = [self.smiles_language.token_indexes_to_smiles(m) for m in indexes] + mol = [m for m in mol if Chem.MolFromSmiles(m) is not None] + return mol diff --git a/src/gt4sd/algorithms/generation/tests/__init__.py b/src/gt4sd/algorithms/generation/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/gt4sd/algorithms/generation/tests/test_hugging_face.py b/src/gt4sd/algorithms/generation/tests/test_hugging_face.py new file mode 100644 index 000000000..5df708790 --- /dev/null +++ b/src/gt4sd/algorithms/generation/tests/test_hugging_face.py @@ -0,0 +1,193 @@ +"""HuggingFace tests.""" + +from typing import ClassVar, Type + +import pytest + +from gt4sd.algorithms.core import AlgorithmConfiguration +from gt4sd.algorithms.generation.hugging_face import ( + HuggingFaceCTRLGenerator, + HuggingFaceGenerationAlgorithm, + HuggingFaceGPT2Generator, + HuggingFaceOpenAIGPTGenerator, + HuggingFaceTransfoXLGenerator, + HuggingFaceXLMGenerator, + HuggingFaceXLNetGenerator, +) +from gt4sd.algorithms.registry import ApplicationsRegistry +from gt4sd.tests.utils import GT4SDTestSettings + +test_settings = GT4SDTestSettings.get_instance() + + +def get_classvar_type(class_var): + """Extract type from ClassVar type annotation: `ClassVar[T]] -> T`.""" + return class_var.__args__[0] + + +@pytest.mark.parametrize( + "config_class, algorithm_type, domain, algorithm_name", + [ + ( + HuggingFaceXLMGenerator, + "generation", + "nlp", + HuggingFaceGenerationAlgorithm.__name__, + ), + ( + HuggingFaceCTRLGenerator, + "generation", + "nlp", + HuggingFaceGenerationAlgorithm.__name__, + ), + ( + HuggingFaceGPT2Generator, + 
"generation", + "nlp", + HuggingFaceGenerationAlgorithm.__name__, + ), + ( + HuggingFaceOpenAIGPTGenerator, + "generation", + "nlp", + HuggingFaceGenerationAlgorithm.__name__, + ), + ( + HuggingFaceXLNetGenerator, + "generation", + "nlp", + HuggingFaceGenerationAlgorithm.__name__, + ), + ( + HuggingFaceTransfoXLGenerator, + "generation", + "nlp", + HuggingFaceGenerationAlgorithm.__name__, + ), + ], +) +def test_config_class( + config_class: Type[AlgorithmConfiguration], + algorithm_type: str, + domain: str, + algorithm_name: str, +): + assert config_class.algorithm_type == algorithm_type + assert config_class.domain == domain + assert config_class.algorithm_name == algorithm_name + + for keyword, type_annotation in config_class.__annotations__.items(): + if keyword in ("algorithm_type", "domain", "algorithm_name"): + assert type_annotation.__origin__ is ClassVar # type: ignore + assert str == get_classvar_type(type_annotation) + + +@pytest.mark.parametrize( + "config_class", + [ + (HuggingFaceXLMGenerator), + (HuggingFaceCTRLGenerator), + (HuggingFaceGPT2Generator), + (HuggingFaceOpenAIGPTGenerator), + (HuggingFaceXLNetGenerator), + (HuggingFaceTransfoXLGenerator), + ], +) +def test_config_instance(config_class: Type[AlgorithmConfiguration]): + config = config_class() # type:ignore + assert config.algorithm_application == config_class.__name__ + + +@pytest.mark.parametrize( + "config_class", + [ + (HuggingFaceXLMGenerator), + (HuggingFaceCTRLGenerator), + (HuggingFaceGPT2Generator), + (HuggingFaceOpenAIGPTGenerator), + (HuggingFaceXLNetGenerator), + (HuggingFaceTransfoXLGenerator), + ], +) +def test_available_versions(config_class: Type[AlgorithmConfiguration]): + versions = config_class.list_versions() + assert len(versions) > 0 + + +@pytest.mark.parametrize( + "config, algorithm", + [ + (HuggingFaceXLMGenerator, HuggingFaceGenerationAlgorithm), + pytest.param( + HuggingFaceCTRLGenerator, + HuggingFaceGenerationAlgorithm, + marks=pytest.mark.skipif(test_settings.gt4sd_ci, reason="slow_runtime"), + ), + (HuggingFaceGPT2Generator, HuggingFaceGenerationAlgorithm), + (HuggingFaceOpenAIGPTGenerator, HuggingFaceGenerationAlgorithm), + (HuggingFaceXLNetGenerator, HuggingFaceGenerationAlgorithm), + (HuggingFaceTransfoXLGenerator, HuggingFaceGenerationAlgorithm), + ], +) +def test_generation_via_import(config, algorithm): + algorithm = algorithm(configuration=config(length=10, number_of_sequences=1)) + items = list(algorithm.sample(1)) + assert len(items) == 1 + + +@pytest.mark.parametrize( + "algorithm_application, algorithm_type, domain, algorithm_name", + [ + ( + HuggingFaceXLMGenerator.__name__, + "generation", + "nlp", + HuggingFaceGenerationAlgorithm.__name__, + ), + pytest.param( + HuggingFaceCTRLGenerator.__name__, + "generation", + "nlp", + HuggingFaceGenerationAlgorithm.__name__, + marks=pytest.mark.skipif(test_settings.gt4sd_ci, reason="slow_runtime"), + ), + ( + HuggingFaceGPT2Generator.__name__, + "generation", + "nlp", + HuggingFaceGenerationAlgorithm.__name__, + ), + ( + HuggingFaceOpenAIGPTGenerator.__name__, + "generation", + "nlp", + HuggingFaceGenerationAlgorithm.__name__, + ), + ( + HuggingFaceXLNetGenerator.__name__, + "generation", + "nlp", + HuggingFaceGenerationAlgorithm.__name__, + ), + ( + HuggingFaceTransfoXLGenerator.__name__, + "generation", + "nlp", + HuggingFaceGenerationAlgorithm.__name__, + ), + ], +) +def test_generation_via_registry( + algorithm_type, domain, algorithm_name, algorithm_application +): + algorithm = ApplicationsRegistry.get_application_instance( + 
+        target=None,
+        algorithm_type=algorithm_type,
+        domain=domain,
+        algorithm_name=algorithm_name,
+        algorithm_application=algorithm_application,
+        length=10,
+        number_of_sequences=1,
+    )
+    items = list(algorithm.sample(1))
+    assert len(items) == 1
diff --git a/src/gt4sd/algorithms/generation/tests/test_molgx.py b/src/gt4sd/algorithms/generation/tests/test_molgx.py
new file mode 100644
index 000000000..2d203aa6a
--- /dev/null
+++ b/src/gt4sd/algorithms/generation/tests/test_molgx.py
@@ -0,0 +1,102 @@
+"""MolGX tests."""
+
+from typing import ClassVar, Type
+
+import pytest
+
+from gt4sd.algorithms.core import AlgorithmConfiguration
+from gt4sd.algorithms.registry import ApplicationsRegistry
+from gt4sd.extras import EXTRAS_ENABLED
+from gt4sd.tests.utils import GT4SDTestSettings
+
+if not EXTRAS_ENABLED:
+    pytest.skip("Extras from custom PyPI disabled", allow_module_level=True)
+else:
+    from gt4sd.algorithms.generation.molgx import MolGX, MolGXQM9Generator
+
+test_settings = GT4SDTestSettings.get_instance()
+
+
+def get_classvar_type(class_var):
+    """Extract type from ClassVar type annotation: `ClassVar[T] -> T`."""
+    return class_var.__args__[0]
+
+
+@pytest.mark.parametrize(
+    "config_class, algorithm_type, domain, algorithm_name",
+    [(MolGXQM9Generator, "generation", "materials", MolGX.__name__)],
+)
+def test_config_class(
+    config_class: Type[AlgorithmConfiguration],
+    algorithm_type: str,
+    domain: str,
+    algorithm_name: str,
+):
+    assert config_class.algorithm_type == algorithm_type
+    assert config_class.domain == domain
+    assert config_class.algorithm_name == algorithm_name
+
+    for keyword, type_annotation in config_class.__annotations__.items():
+        if keyword in ("algorithm_type", "domain", "algorithm_name"):
+            assert type_annotation.__origin__ is ClassVar  # type: ignore
+            assert str == get_classvar_type(type_annotation)
+
+
+@pytest.mark.parametrize(
+    "config_class",
+    [(MolGXQM9Generator)],
+)
+def test_config_instance(config_class: Type[AlgorithmConfiguration]):
+    config = config_class()  # type:ignore
+    assert config.algorithm_application == config_class.__name__
+
+
+@pytest.mark.parametrize(
+    "config_class",
+    [(MolGXQM9Generator)],
+)
+def test_available_versions(config_class: Type[AlgorithmConfiguration]):
+    versions = config_class.list_versions()
+    assert "v0" in versions
+
+
+@pytest.mark.parametrize(
+    "config, algorithm",
+    [
+        pytest.param(
+            MolGXQM9Generator,
+            MolGX,
+            marks=pytest.mark.skipif(test_settings.gt4sd_ci, reason="slow_runtime"),
+        )
+    ],
+)
+def test_generation_via_import(config, algorithm):
+    algorithm = algorithm(configuration=config())
+    items = list(algorithm.sample(3))
+    assert len(items) == 3
+
+
+@pytest.mark.parametrize(
+    "algorithm_application, algorithm_type, domain, algorithm_name",
+    [
+        pytest.param(
+            MolGXQM9Generator.__name__,
+            "generation",
+            "materials",
+            MolGX.__name__,
+            marks=pytest.mark.skipif(test_settings.gt4sd_ci, reason="slow_runtime"),
+        ),
+    ],
+)
+def test_generation_via_registry(
+    algorithm_type, domain, algorithm_name, algorithm_application
+):
+    algorithm = ApplicationsRegistry.get_application_instance(
+        target=None,
+        algorithm_type=algorithm_type,
+        domain=domain,
+        algorithm_name=algorithm_name,
+        algorithm_application=algorithm_application,
+    )
+    items = list(algorithm.sample(3))
+    assert len(items) == 3
diff --git a/src/gt4sd/algorithms/generation/tests/test_polymer_blocks.py b/src/gt4sd/algorithms/generation/tests/test_polymer_blocks.py
new file mode 100644
index 000000000..519a5a59d
--- /dev/null
+++ b/src/gt4sd/algorithms/generation/tests/test_polymer_blocks.py
@@ -0,0 +1,103 @@
+"""PolymerBlocks tests."""
+
+from typing import ClassVar, Type
+
+import pytest
+
+from gt4sd.algorithms.core import AlgorithmConfiguration
+from gt4sd.algorithms.generation.polymer_blocks import (
+    PolymerBlocks,
+    PolymerBlocksGenerator,
+)
+from gt4sd.algorithms.registry import ApplicationsRegistry
+
+
+def get_classvar_type(class_var):
+    """Extract type from ClassVar type annotation: `ClassVar[T] -> T`."""
+    return class_var.__args__[0]
+
+
+@pytest.mark.parametrize(
+    "config_class, algorithm_type, domain, algorithm_name",
+    [
+        (
+            PolymerBlocksGenerator,
+            "generation",
+            "materials",
+            PolymerBlocks.__name__,
+        )
+    ],
+)
+def test_config_class(
+    config_class: Type[AlgorithmConfiguration],
+    algorithm_type: str,
+    domain: str,
+    algorithm_name: str,
+):
+    assert config_class.algorithm_type == algorithm_type
+    assert config_class.domain == domain
+    assert config_class.algorithm_name == algorithm_name
+
+    for keyword, type_annotation in config_class.__annotations__.items():
+        if keyword in ("algorithm_type", "domain", "algorithm_name"):
+            assert type_annotation.__origin__ is ClassVar  # type: ignore
+            assert str == get_classvar_type(type_annotation)
+
+
+@pytest.mark.parametrize(
+    "config_class",
+    [(PolymerBlocksGenerator)],
+)
+def test_config_instance(config_class: Type[AlgorithmConfiguration]):
+    config = config_class()  # type:ignore
+    assert config.algorithm_application == config_class.__name__
+
+
+@pytest.mark.parametrize(
+    "config_class",
+    [(PolymerBlocksGenerator)],
+)
+def test_available_versions(config_class: Type[AlgorithmConfiguration]):
+    versions = config_class.list_versions()
+    assert "v0" in versions
+
+
+@pytest.mark.parametrize(
+    "config, algorithm",
+    [
+        (
+            PolymerBlocksGenerator,
+            PolymerBlocks,
+        )
+    ],
+)
+def test_generation_via_import(config, algorithm):
+    algorithm = algorithm(configuration=config())
+    items = list(algorithm.sample(5))
+    assert len(items) == 5
+
+
+@pytest.mark.parametrize(
+    "algorithm_application, algorithm_type, domain, algorithm_name",
+    [
+        (
+            PolymerBlocksGenerator.__name__,
+            "generation",
+            "materials",
+            PolymerBlocks.__name__,
+        ),
+    ],
+)
+def test_generation_via_registry(
+    algorithm_type, domain, algorithm_name, algorithm_application
+):
+    algorithm = ApplicationsRegistry.get_application_instance(
+        target=None,
+        algorithm_type=algorithm_type,
+        domain=domain,
+        algorithm_name=algorithm_name,
+        algorithm_application=algorithm_application,
+        generated_length=5,
+    )
+    items = list(algorithm.sample(5))
+    assert len(items) == 5
diff --git a/src/gt4sd/algorithms/prediction/__init__.py b/src/gt4sd/algorithms/prediction/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/gt4sd/algorithms/prediction/core.py b/src/gt4sd/algorithms/prediction/core.py
new file mode 100644
index 000000000..f354139f8
--- /dev/null
+++ b/src/gt4sd/algorithms/prediction/core.py
@@ -0,0 +1,11 @@
+"""Property prediction algorithms."""
+
+from typing import TypeVar
+
+# from ..core import PropertyPredictor
+# from ...domains.materials import ConditionPAG
+
+
+S = TypeVar("S")
+# T = TypeVar('T') CLaSSVAE uses List[PropertyPredictor[S, U]] as target
+U = TypeVar("U")
diff --git a/src/gt4sd/algorithms/prediction/paccmann.py b/src/gt4sd/algorithms/prediction/paccmann.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/gt4sd/algorithms/prediction/tests/__init__.py b/src/gt4sd/algorithms/prediction/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/gt4sd/algorithms/prediction/tests/test_topics_zero_shot.py b/src/gt4sd/algorithms/prediction/tests/test_topics_zero_shot.py
new file mode 100644
index 000000000..9070cebbd
--- /dev/null
+++ b/src/gt4sd/algorithms/prediction/tests/test_topics_zero_shot.py
@@ -0,0 +1,101 @@
+"""TopicsZeroShot tests."""
+
+from typing import ClassVar, Type
+
+import pytest
+
+from gt4sd.algorithms.core import AlgorithmConfiguration
+from gt4sd.algorithms.prediction.topics_zero_shot import TopicsPredictor, TopicsZeroShot
+from gt4sd.algorithms.registry import ApplicationsRegistry
+
+
+def get_classvar_type(class_var):
+    """Extract type from ClassVar type annotation: `ClassVar[T] -> T`."""
+    return class_var.__args__[0]
+
+
+@pytest.mark.parametrize(
+    "config_class, algorithm_type, domain, algorithm_name",
+    [
+        (
+            TopicsPredictor,
+            "prediction",
+            "nlp",
+            TopicsZeroShot.__name__,
+        )
+    ],
+)
+def test_config_class(
+    config_class: Type[AlgorithmConfiguration],
+    algorithm_type: str,
+    domain: str,
+    algorithm_name: str,
+):
+    assert config_class.algorithm_type == algorithm_type
+    assert config_class.domain == domain
+    assert config_class.algorithm_name == algorithm_name
+
+    for keyword, type_annotation in config_class.__annotations__.items():
+        if keyword in ("algorithm_type", "domain", "algorithm_name"):
+            assert type_annotation.__origin__ is ClassVar  # type: ignore
+            assert str == get_classvar_type(type_annotation)
+
+
+@pytest.mark.parametrize(
+    "config_class",
+    [(TopicsPredictor)],
+)
+def test_config_instance(config_class: Type[AlgorithmConfiguration]):
+    config = config_class()  # type:ignore
+    assert config.algorithm_application == config_class.__name__
+
+
+@pytest.mark.parametrize(
+    "config_class",
+    [(TopicsPredictor)],
+)
+def test_available_versions(config_class: Type[AlgorithmConfiguration]):
+    versions = config_class.list_versions()
+    assert "dbpedia" in versions
+
+
+@pytest.mark.parametrize(
+    "config, target, algorithm",
+    [
+        (
+            TopicsPredictor,
+            "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration.",
+            TopicsZeroShot,
+        )
+    ],
+)
+def test_generation_via_import(config, target, algorithm):
+    algorithm = algorithm(configuration=config(), target=target)
+    items = list(algorithm.sample(5))
+    assert len(items) == 5
+
+
+@pytest.mark.parametrize(
+    "algorithm_application, target, algorithm_type, domain, algorithm_name",
+    [
+        (
+            TopicsPredictor.__name__,
+            "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration.",
+            "prediction",
+            "nlp",
+            TopicsZeroShot.__name__,
+        ),
+    ],
+)
+def test_generation_via_registry(
+    algorithm_type, target, domain, algorithm_name, algorithm_application
+):
+    algorithm = ApplicationsRegistry.get_application_instance(
+        target=target,
+        algorithm_type=algorithm_type,
+        domain=domain,
+        algorithm_name=algorithm_name,
+        algorithm_application=algorithm_application,
+    )
+    items = list(algorithm.sample(5))
+    assert len(items) == 5
diff --git a/src/gt4sd/algorithms/prediction/topics_zero_shot/__init__.py b/src/gt4sd/algorithms/prediction/topics_zero_shot/__init__.py
new file mode 100644
index 000000000..6b81953f9
--- /dev/null
+++ b/src/gt4sd/algorithms/prediction/topics_zero_shot/__init__.py
@@ -0,0 +1,5 @@
+"""Topics modelling with zero-shot learning initialization."""
+
+from .core import TopicsPredictor, TopicsZeroShot
+
+__all__ = ["TopicsZeroShot", "TopicsPredictor"]
["TopicsZeroShot", "TopicsPredictor"] diff --git a/src/gt4sd/algorithms/prediction/topics_zero_shot/core.py b/src/gt4sd/algorithms/prediction/topics_zero_shot/core.py new file mode 100644 index 000000000..21c19c428 --- /dev/null +++ b/src/gt4sd/algorithms/prediction/topics_zero_shot/core.py @@ -0,0 +1,111 @@ +"""Algortihms for topic modelling using zero-shot learning via MLNI models.""" + +import logging +from dataclasses import field +from typing import Any, Callable, ClassVar, Dict, Iterable, Optional, TypeVar + +from ...core import AlgorithmConfiguration, GeneratorAlgorithm +from ...registry import ApplicationsRegistry +from .implementation import ZeroShotClassifier + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + +T = TypeVar("T", bound=Any) +S = TypeVar("S", bound=Any) +Targeted = Callable[[T], Iterable[Any]] + + +class TopicsZeroShot(GeneratorAlgorithm[S, T]): + """Topics prediction algorithm.""" + + def __init__( + self, + configuration: AlgorithmConfiguration[S, T], + target: Optional[T], + ): + """Instantiate TopicsZeroShot ready to predict topics. + + Args: + configuration: domain and application + specification defining parameters, types and validations. + target: a target for which to generate items. + + Example: + An example for predicting topics for a given text:: + + config = TopicsPredictor() + algorithm = TopicsZeroShot(configuration=config, target="This is a text I want to understand better") + items = list(algorithm.sample(1)) + print(items) + """ + + configuration = self.validate_configuration(configuration) + # TODO there might also be a validation/check on the target input + + super().__init__( + configuration=configuration, # type:ignore + target=target, # type:ignore + ) + + def get_generator( + self, + configuration: AlgorithmConfiguration[S, T], + target: Optional[T], + ) -> Targeted[T]: + """Get the function to perform the prediction via TopicsZeroShot's generator. + + Args: + configuration: helps to set up specific application of TopicsZeroShot. + target: context or condition for the generation. + + Returns: + callable with target predicting topics sorted by relevance. + """ + logger.info("ensure artifacts for the application are present.") + self.local_artifacts = configuration.ensure_artifacts() + implementation: ZeroShotClassifier = configuration.get_conditional_generator( # type: ignore + self.local_artifacts + ) + return implementation.predict + + +@ApplicationsRegistry.register_algorithm_application(TopicsZeroShot) +class TopicsPredictor(AlgorithmConfiguration[str, str]): + """Configuration to generate topics.""" + + algorithm_type: ClassVar[str] = "prediction" + domain: ClassVar[str] = "nlp" + algorithm_version: str = "dbpedia" + + model_name: str = field( + default="facebook/bart-large-mnli", + metadata=dict( + description="MLNI model name to use. If the model is not found in the cache, a download from HuggingFace will be attempted." + ), + ) + + def get_target_description(self) -> Dict[str, str]: + """Get description of the target for generation. + + Returns: + target description. + """ + return { + "title": "Text to analyze", + "description": "Text considered for the topics prediction task.", + "type": "string", + } + + def get_conditional_generator(self, resources_path: str) -> ZeroShotClassifier: + """Instantiate the actual generator implementation. + + Args: + resources_path: local path to model files. + + Returns: + instance with :meth:`generate_batch` method for targeted generation. 
+ """ + return ZeroShotClassifier( + resources_path=resources_path, model_name=self.model_name + ) diff --git a/src/gt4sd/algorithms/prediction/topics_zero_shot/implementation.py b/src/gt4sd/algorithms/prediction/topics_zero_shot/implementation.py new file mode 100644 index 000000000..11c087c49 --- /dev/null +++ b/src/gt4sd/algorithms/prediction/topics_zero_shot/implementation.py @@ -0,0 +1,79 @@ +"""Implementation of the zero-shot classifier.""" + +import json +import logging +import os +from typing import List, Optional, Union + +import torch +from transformers import pipeline + +from ....frameworks.torch import device_claim + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class ZeroShotClassifier: + """ + Zero-shot classifier based on the HuggingFace pipeline leveraging MLNI. + """ + + def __init__( + self, + resources_path: str, + model_name: str, + device: Optional[Union[torch.device, str]] = None, + ): + """Initialize ZeroShotClassifier. + + Args: + resources_path: path where to load hypothesis, candidate labels and, optionally, the model. + model_name: name of the model to load from the cache or download from HuggingFace. + device: device where the inference + is running either as a dedicated class or a string. If not provided is inferred. + """ + device = device_claim(device) + self.device = -1 if device.type == "cpu" else int(device.type.split(":")[1]) + self.resources_path = resources_path + self.model_name = model_name + self.load_pipeline() + + def load_pipeline(self) -> None: + """Load zero shot classification MLNI pipeline.""" + metadata_filepath = os.path.join(self.resources_path, "metadata.json") + if os.path.exists(metadata_filepath): + with open(metadata_filepath) as fp: + metadata = json.load(fp) + self.labels = metadata["labels"] + self.hypothesis_template = metadata["hypothesis_template"] + self.model_name_or_path = os.path.join(self.resources_path, self.model_name) + if not os.path.exists(self.model_name_or_path): + logger.info( + f"no model named {self.model_name_or_path} in cache, using HuggingFace" + ) + self.model_name_or_path = self.model_name + else: + message = f"could not retrieve the MLNI pipeline from the cache: {metadata_filepath} does not exists!" + logger.error(message) + raise ValueError(message) + self.model = pipeline( + "zero-shot-classification", + model=self.model_name_or_path, + device=self.device, + ) + + def predict(self, text: str) -> List[str]: + """Get sorted classification labels. + + Args: + text: text to classify. + + Returns: + labels sorted by score from highest to lowest. 
+ """ + return self.model( + text, + candidate_labels=self.labels, + hypothesis_template=self.hypothesis_template, + )["labels"] diff --git a/src/gt4sd/algorithms/registry.py b/src/gt4sd/algorithms/registry.py new file mode 100644 index 000000000..80598f7ed --- /dev/null +++ b/src/gt4sd/algorithms/registry.py @@ -0,0 +1,348 @@ +"""Collection of available methods.""" + + +import logging +from dataclasses import dataclass as vanilla_dataclass +from dataclasses import field, make_dataclass +from functools import WRAPPER_ASSIGNMENTS, update_wrapper +from typing import Any, Callable, ClassVar, Dict, List, NamedTuple, Optional, Type + +import pydantic + +# pyright (pylance in VSCode) does not support pydantic typechecking +# if typing.TYPE_CHECKING: +# from dataclasses import dataclass +# else: +# from pydantic.dataclasses import dataclass +from pydantic.dataclasses import dataclass + +from ..exceptions import DuplicateApplicationRegistration +from .core import AlgorithmConfiguration, GeneratorAlgorithm + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class ConfigurationTuple(NamedTuple): + """Attributes to uniquely describe an AlgorithmConfiguration.""" + + algorithm_type: str + domain: str + algorithm_name: str + algorithm_application: str + + +class AnnotationTuple(NamedTuple): + annotation: type + default_value: Any # TODO serializable type? + + +@vanilla_dataclass +class AlgorithmApplication: + """Collect all needed to run an application.""" + + algorithm_class: Type[GeneratorAlgorithm] + configuration_class: Type[AlgorithmConfiguration] + parameters_dict: Dict[str, AnnotationTuple] = field(default_factory=dict) + # includes algorithm_version: str + + +class RegistryDict(Dict[ConfigurationTuple, AlgorithmApplication]): + """Dict that raises when reassigning an existing key.""" + + def __setitem__(self, key, value): + if self.__contains__(key): + raise DuplicateApplicationRegistration( + title="Applications exists", + detail=f"key {key} was already registered and would override another application.", + ) + # if it's really needed for some reason, delete the item first, then add it. + else: + super().__setitem__(key, value) + + +class ApplicationsRegistry: + """Registry to collect "applications" and make them accessible. + + An application denotes the combination of an + :class:`AlgorithmConfiguration` and a + :class:`GeneratorAlgorithm`. + """ + + # NOTE on import of registy also ensure import of modules to populate applications + applications: RegistryDict = RegistryDict() + + @classmethod + def _register_application( + cls, + algorithm_class: Type[GeneratorAlgorithm], + algorithm_configuration_class: Type[AlgorithmConfiguration], + ): + # testing that configuration class is callable without arguments + try: + algorithm_configuration_class() + except pydantic.ValidationError as e: + logger.exception(e) + config_tuple = cls.configuration_class_as_tuple(algorithm_configuration_class) + cls.applications[config_tuple] = AlgorithmApplication( + algorithm_class=algorithm_class, + configuration_class=algorithm_configuration_class, + ) + + @classmethod + def register_algorithm_application( + cls, + algorithm_class: Type[GeneratorAlgorithm], + as_algorithm_application: Optional[str] = None, + ) -> Callable[[Type[AlgorithmConfiguration]], Type[AlgorithmConfiguration]]: + """Complete and register a configuration via decoration. + + Args: + algorithm_class: The algorithm that uses the configuration. 
+ as_algorithm_application: Optional application name to use instead of + the configurations class name. + + Returns: + A function to complete the configuration class' attributes to reflect + the matching GeneratorAlgorithm and application. The final class is + registered and returned. + + Example: + as decorator:: + + from gt4sd.algorithms.registry import ApplicationsRegistry + + + @ApplicationsRegistry.register_algorithm_application(SomeAlgorithm) + class SomeApplication(AlgorithmConfiguration): + algorithm_type: ClassVar[str] = 'conditional_generation' + domain: ClassVar[str] = 'materials' + algorithm_version: str = 'v0' + + some_more_serializable_arguments_with_defaults: int = 42 + + Example: + directly, here for an additional algorithm application with the same + algorithm:: + + AnotherApplication = ApplicationsRegistry.register_algorithm_application( + SomeAlgorithm, 'AnotherApplication' + )(SomeApplication) + """ + + def decorator( + configuration_class: Type[AlgorithmConfiguration], + ) -> Type[AlgorithmConfiguration]: + """Complete the configuration class' attributes and register the class. + + Args: + configuration_class: class to complete. + + Returns: + a completed class. + """ + VanillaConfiguration = make_dataclass( + cls_name=configuration_class.__name__, + # call `@dataclass` for users to avoid confusion + bases=(vanilla_dataclass(configuration_class),), + fields=[ + ( + "algorithm_name", # type: ignore + ClassVar[str], + field(default=algorithm_class.__name__), # type: ignore + ), + ( + "algorithm_application", # type: ignore + ClassVar[str], + field( + default=( + as_algorithm_application or configuration_class.__name__ # type: ignore + ) + ), + ), + ], # type: ignore + ) + + PydanticConfiguration: Type[AlgorithmConfiguration] = dataclass( # type: ignore + VanillaConfiguration + ) + # get missing entries + missing_in__dict__ = [ + key + for key in configuration_class.__dict__ + if key not in PydanticConfiguration.__dict__ + ] + + update_wrapper( + wrapper=PydanticConfiguration, + wrapped=configuration_class, + assigned=missing_in__dict__ + list(WRAPPER_ASSIGNMENTS), + updated=(), # default of '__dict__' does not apply here, see missing_in__dict__ + ) + + cls._register_application(algorithm_class, PydanticConfiguration) + + return PydanticConfiguration + + return decorator + + @staticmethod + def configuration_class_as_tuple( + algorithm_configuration_class: Type[AlgorithmConfiguration], + ) -> "ConfigurationTuple": + """Get a hashable identifier per application.""" + return ConfigurationTuple( + algorithm_type=algorithm_configuration_class.algorithm_type, + domain=algorithm_configuration_class.domain, + algorithm_name=algorithm_configuration_class.algorithm_name, + algorithm_application=algorithm_configuration_class.algorithm_application, + ) + + @classmethod + def get_application( + cls, + algorithm_type: str, + domain: str, + algorithm_name: str, + algorithm_application: str, + ) -> AlgorithmApplication: + return cls.applications[ + ConfigurationTuple( + algorithm_type=algorithm_type, + domain=domain, + algorithm_name=algorithm_name, + algorithm_application=algorithm_application, + ) + ] + + @classmethod + def get_matching_configuration_defaults( + cls, + algorithm_type: str, + domain: str, + algorithm_name: str, + algorithm_application: str, + ) -> Dict[str, AnnotationTuple]: + Configuration = cls.get_application( + algorithm_type=algorithm_type, + domain=domain, + algorithm_name=algorithm_name, + algorithm_application=algorithm_application, + ).configuration_class + + 
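+        # NOTE: instantiating the configuration without arguments below assumes
+        # all its fields define defaults; this is checked at registration time
+        # in `_register_application`.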
defaults_dict = {} + for ( + argument, + default_value, + ) in Configuration().__dict__.items(): + defaults_dict[argument] = AnnotationTuple( + annotation=Configuration.__annotations__[argument], + default_value=default_value, + ) + return defaults_dict + + @classmethod + def get_matching_configuration_schema( + cls, + algorithm_type: str, + domain: str, + algorithm_name: str, + algorithm_application: str, + ) -> Dict[str, Any]: + Configuration = cls.get_application( + algorithm_type=algorithm_type, + domain=domain, + algorithm_name=algorithm_name, + algorithm_application=algorithm_application, + ).configuration_class + return Configuration.__pydantic_model__.schema() # type: ignore + + @classmethod + def get_configuration_instance( + cls, + algorithm_type: str, + domain: str, + algorithm_name: str, + algorithm_application: str, + *args, + **kwargs, + ) -> AlgorithmConfiguration: + """Create an instance of the matching AlgorithmConfiguration from the ApplicationsRegistry. + + Args: + algorithm_type: general type of generative algorithm. + domain: general application domain. Hints at input/output types. + algorithm_name: name of the algorithm to use with this configuration. + algorithm_application: unique name for the application that is the use of this + configuration together with a specific algorithm. + algorithm_version: to differentiate between different versions of an application. + *args: additional positional arguments passed to the configuration. + **kwargs: additional keyword arguments passed to the configuration. + + Returns: + an instance of the configuration. + """ + Configuration = cls.get_application( + algorithm_type=algorithm_type, + domain=domain, + algorithm_name=algorithm_name, + algorithm_application=algorithm_application, + ).configuration_class + return Configuration(*args, **kwargs) + + @classmethod + def get_application_instance( + cls, + algorithm_type: str, + domain: str, + algorithm_name: str, + algorithm_application: str, + target: Any = None, + **kwargs, + ) -> GeneratorAlgorithm: + """Instantiate an algorithm via a matching application from the ApplicationsRegistry. + + Additional arguments are passed to the configuration and override any arguments + in the ApplicationsRegistry. + + Args: + algorithm_type: general type of generative algorithm. + domain: general application domain. Hints at input/output types. + algorithm_name: name of the algorithm to use with this configuration. + algorithm_application: unique name for the application that is the use of this + configuration together with a specific algorithm. + algorithm_version: to differentiate between different versions of an application. + target: optional context or condition for the generation. + **kwargs: additional keyword arguments passed to the configuration. + + Returns: + an instance of a generative algorithm ready to sample from. 
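+
+        Example:
+            a sketch using the application names exercised in the tests;
+            any other registered configuration tuple works the same way::
+
+                algorithm = ApplicationsRegistry.get_application_instance(
+                    algorithm_type='conditional_generation',
+                    domain='materials',
+                    algorithm_name='PaccMannRL',
+                    algorithm_application='PaccMannRLProteinBasedGenerator',
+                    target='MVLSPADKTNVKAAW',
+                    batch_size=8,
+                )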
+ """ + application_tuple = cls.get_application( + algorithm_type=algorithm_type, + domain=domain, + algorithm_name=algorithm_name, + algorithm_application=algorithm_application, + ) + parameters = { + key: annotation_tuple.default_value + for key, annotation_tuple in application_tuple.parameters_dict.items() + } + parameters.update(kwargs) + + return application_tuple.algorithm_class( + configuration=application_tuple.configuration_class(**parameters), + target=target, + ) + + @classmethod + def list_available(cls) -> List[Dict[str, str]]: + available = [] + for config_tuple, application in cls.applications.items(): + available.extend( + [ + dict(**config_tuple._asdict(), algorithm_version=version) + for version in application.configuration_class.list_versions() + ] + ) + return available diff --git a/src/gt4sd/algorithms/tests/__init__.py b/src/gt4sd/algorithms/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/gt4sd/algorithms/tests/test_config.py b/src/gt4sd/algorithms/tests/test_config.py new file mode 100644 index 000000000..916d75517 --- /dev/null +++ b/src/gt4sd/algorithms/tests/test_config.py @@ -0,0 +1,47 @@ +"""Tests for AlgorithmConfiguration.""" + +import os +import shutil +from typing import ClassVar + +import pytest + +from gt4sd.algorithms.core import AlgorithmConfiguration +from gt4sd.configuration import GT4SDConfiguration + +gt4sd_configuration_instance = GT4SDConfiguration.get_instance() + + +@pytest.fixture() +def development_version_path(): + # setup + path = os.path.join( + gt4sd_configuration_instance.gt4sd_local_cache_path, + gt4sd_configuration_instance.gt4sd_local_cache_path_algorithms, + "dummy", + "algorithm", + "config", + "development", + ) + os.makedirs(path, exist_ok=False) + # test + yield path + # teardown + shutil.rmtree( + os.path.join( + gt4sd_configuration_instance.gt4sd_local_cache_path, + gt4sd_configuration_instance.gt4sd_local_cache_path_algorithms, + "dummy", + ) + ) + + +def test_list_versions_local_only(development_version_path): + class Config(AlgorithmConfiguration): + algorithm_type: ClassVar[str] = "dummy" + domain: ClassVar[str] = "" + algorithm_name: ClassVar[str] = "algorithm" + algorithm_application: ClassVar[str] = "config" + algorithm_version: str = "development" + + assert "development" in Config.list_versions() diff --git a/src/gt4sd/algorithms/tests/test_registry.py b/src/gt4sd/algorithms/tests/test_registry.py new file mode 100644 index 000000000..96a5afb72 --- /dev/null +++ b/src/gt4sd/algorithms/tests/test_registry.py @@ -0,0 +1,138 @@ +"""Tests for registry that are independent of specific registrations.""" + +import pickle +from typing import ClassVar + +import pytest +from pydantic import ValidationError + +from gt4sd.algorithms.core import AlgorithmConfiguration, GeneratorAlgorithm +from gt4sd.algorithms.registry import ApplicationsRegistry +from gt4sd.exceptions import DuplicateApplicationRegistration + +# there are at least 2 available versions, 1 per PaccMannRL configuration +AT_LEAST = 2 + + +def assert_pickable(obj): + pickled_obj = pickle.dumps(obj) + restored_obj = pickle.loads(pickled_obj) + + assert restored_obj.algorithm_version == "test" + assert restored_obj == obj + + return restored_obj + + +def test_list_available_s3(): + len(ApplicationsRegistry.list_available()) + assert len(ApplicationsRegistry.list_available()) >= AT_LEAST + + +def test_list_available_local_via_S3SyncError(mock_wrong_s3_env): + assert len(ApplicationsRegistry.list_available()) >= AT_LEAST + + +def 
test_list_available_local_via_KeyError(mock_missing_s3_env): + assert len(ApplicationsRegistry.list_available()) >= AT_LEAST + + +def test_inherited_validation(): + Config = next(iter(ApplicationsRegistry.applications.values())).configuration_class + with pytest.raises( + ValidationError, match="algorithm_version\n +none is not an allowed value" + ): + Config(algorithm_version=None) # type: ignore + + # NOTE: values convertible to string will not raise! + Config(algorithm_version=5) # type: ignore + + +def test_validation(): + with pytest.raises( + ValidationError, match="batch_size\n +value is not a valid integer" + ): + ApplicationsRegistry.get_configuration_instance( + algorithm_type="conditional_generation", + domain="materials", + algorithm_name="PaccMannRL", + algorithm_application="PaccMannRLProteinBasedGenerator", + batch_size="wrong_type", + ) + + +def test_pickable_wrapped_configurations(): + # https://github.com/samuelcolvin/pydantic/issues/2111 + Config = next(iter(ApplicationsRegistry.applications.values())).configuration_class + restored_obj = assert_pickable(Config(algorithm_version="test")) + + # wrong type assignment, but we did not configure it to raise here: + restored_obj.algorithm_version = object + # ensure the restored dataclass is still a pydantic dataclass (mimic validation) + _, optional_errors = restored_obj.__pydantic_model__.__fields__.get( + "algorithm_version" + ).validate( + restored_obj.algorithm_version, + restored_obj.__dict__, + loc="algorithm_version", + cls=restored_obj.__class__, + ) + assert optional_errors is not None + + +def test_multiple_registration(): + class OtherAlgorithm(GeneratorAlgorithm): + pass + + @ApplicationsRegistry.register_algorithm_application( + GeneratorAlgorithm # type:ignore + ) + @ApplicationsRegistry.register_algorithm_application(OtherAlgorithm) # type:ignore + class Config(AlgorithmConfiguration): + algorithm_type: ClassVar[str] = "dummy" + domain: ClassVar[str] = "" + algorithm_version: str = "development" + + # the class wrapping was applied twice + config_class = ApplicationsRegistry.get_application( + algorithm_type="dummy", + domain="", + algorithm_name="GeneratorAlgorithm", + algorithm_application="Config", + ).configuration_class + assert config_class is Config + assert config_class.algorithm_name == "GeneratorAlgorithm" + assert config_class.algorithm_application == "Config" + # __wrapped__? 
+
+    # retrieve singly wrapped config
+    other_config_class = ApplicationsRegistry.get_application(
+        algorithm_type="dummy",
+        domain="",
+        algorithm_name="OtherAlgorithm",
+        algorithm_application="Config",
+    ).configuration_class
+    assert other_config_class is not Config
+    assert other_config_class.algorithm_name == "OtherAlgorithm"
+    assert other_config_class.algorithm_application == "Config"
+
+    # registering Config directly and with explicit algorithm_application
+    ExplicitConfig = ApplicationsRegistry.register_algorithm_application(
+        GeneratorAlgorithm,  # type:ignore
+        as_algorithm_application="ExplicitApplication",
+    )(Config)
+    explicit_config_class = ApplicationsRegistry.get_application(
+        algorithm_type="dummy",
+        domain="",
+        algorithm_name="GeneratorAlgorithm",
+        algorithm_application="ExplicitApplication",
+    ).configuration_class
+    assert explicit_config_class is ExplicitConfig
+    assert explicit_config_class.algorithm_name == "GeneratorAlgorithm"
+    assert explicit_config_class.algorithm_application == "ExplicitApplication"
+
+    # overwriting value in applications is not allowed, applications are unique
+    with pytest.raises(DuplicateApplicationRegistration):
+        ApplicationsRegistry.register_algorithm_application(
+            GeneratorAlgorithm  # type:ignore
+        )(Config)
diff --git a/src/gt4sd/cli/__init__.py b/src/gt4sd/cli/__init__.py
new file mode 100644
index 000000000..15d0664f6
--- /dev/null
+++ b/src/gt4sd/cli/__init__.py
@@ -0,0 +1 @@
+"""GT4SD CLI module initialization."""
diff --git a/src/gt4sd/cli/argument_parser.py b/src/gt4sd/cli/argument_parser.py
new file mode 100644
index 000000000..354d4b6e5
--- /dev/null
+++ b/src/gt4sd/cli/argument_parser.py
@@ -0,0 +1,150 @@
+"""Argument parser for training pipelines."""
+
+
+import dataclasses
+import re
+from argparse import ArgumentTypeError
+from enum import Enum
+from functools import partial
+from typing import Any, List, NewType, Optional, Type, Union
+
+from transformers import HfArgumentParser
+
+
+def none_checker_bool(val: Union[bool, str]) -> Union[bool, None]:
+    """Check given bool argument for None.
+
+    Args:
+        val: boolean value or its string representation.
+
+    Returns:
+        Bool value or None.
+    """
+    if not val:
+        return None
+    if isinstance(val, bool):
+        return val
+    if val.lower() in ("yes", "true", "t", "y", "1"):
+        return True
+    elif val.lower() in ("no", "false", "f", "n", "0"):
+        return False
+    else:
+        raise ArgumentTypeError(
+            f"Truthy value expected: got {val} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
+        )
+
+
+def none_checker(val: Any, dtype: Type) -> Any:
+    """Check given argument for None.
+
+    Args:
+        val: argument value.
+        dtype: expected argument type.
+
+    Returns:
+        Value cast to the expected type, or None.
+    """
+    if not val or val == "none":
+        return None
+    return dtype(val)
+
+
+DataClass = NewType("DataClass", Any)  # type: ignore
+DataClassType = NewType("DataClassType", Any)  # type: ignore
+
+
+class ArgumentParser(HfArgumentParser):
+    """ArgumentParser inheriting from the Hugging Face parser, with a modified dataclass-argument addition for better handling of None values."""
+
+    def _add_dataclass_arguments(self, dtype: DataClassType) -> None:
+        """Add arguments from a dataclass.
+
+        Args:
+            dtype: data class type.
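+
+        Example:
+            a sketch of the resulting None handling for a hypothetical dataclass::
+
+                @dataclass
+                class TrainArgs:
+                    epochs: Optional[int] = None
+
+                parser = ArgumentParser(TrainArgs)
+                # `--epochs 3` parses to 3, `--epochs none` parses to None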
+ """ + + if hasattr(dtype, "_argument_group_name"): + parser = self.add_argument_group(dtype._argument_group_name) + else: + parser = self # type: ignore + for field in dataclasses.fields(dtype): + if not field.init: + continue + field_name = f"--{field.name}" + kwargs = field.metadata.copy() # type: ignore + # field.metadata is not used at all by Data Classes, + # it is provided as a third-party extension mechanism. + if isinstance(field.type, str): + raise ImportError( + "This implementation is not compatible with Postponed Evaluation of Annotations (PEP 563)," + "which can be opted in from Python 3.7 with `from __future__ import annotations`." + "We will add compatibility when Python 3.9 is released." + ) + typestring = str(field.type) + for prim_type in (int, float, str): + for collection in (List,): + if ( + typestring == f"typing.Union[{collection[prim_type]}, NoneType]" # type: ignore + or typestring == f"typing.Optional[{collection[prim_type]}]" # type: ignore + ): + field.type = collection[prim_type] # type: ignore + if ( + typestring == f"typing.Union[{prim_type.__name__}, NoneType]" + or typestring == f"typing.Optional[{prim_type.__name__}]" + ): + field.type = prim_type + + if isinstance(field.type, type) and issubclass(field.type, Enum): + kwargs["choices"] = [x.value for x in field.type] + kwargs["type"] = type(kwargs["choices"][0]) + if field.default is not dataclasses.MISSING: + kwargs["default"] = field.default + else: + kwargs["required"] = True + elif field.type is bool or field.type == Optional[bool]: + + if field.default is True: + parser.add_argument( + f"--no_{field.name}", + action="store_false", + dest=field.name, + **kwargs, + ) + + # Hack because type=bool in argparse does not behave as we want. + kwargs["type"] = none_checker_bool + if field.type is bool or ( + field.default is not None + and field.default is not dataclasses.MISSING + ): + # Default value is False if we have no default when of type bool. + default = ( + False if field.default is dataclasses.MISSING else field.default + ) + # This is the value that will get picked if we don't include --field_name in any way + kwargs["default"] = default + # This tells argparse we accept 0 or 1 value after --field_name + kwargs["nargs"] = "?" 
+                # This is the value that will get picked if we do --field_name (without value)
+                kwargs["const"] = True
+            elif (
+                hasattr(field.type, "__origin__")
+                and re.search(r"^typing\.List\[(.*)\]$", str(field.type)) is not None
+            ):
+                kwargs["nargs"] = "+"
+                kwargs["type"] = partial(none_checker, dtype=field.type.__args__[0])
+                # all list items must share the element type handled by none_checker
+                assert all(
+                    x == field.type.__args__[0] for x in field.type.__args__
+                ), f"{field.name} cannot be a List of mixed types"
+                if field.default_factory is not dataclasses.MISSING:  # type: ignore
+                    kwargs["default"] = field.default_factory()  # type: ignore
+                elif field.default is dataclasses.MISSING:
+                    kwargs["required"] = True
+            else:
+                kwargs["type"] = partial(none_checker, dtype=field.type)
+                if field.default is not dataclasses.MISSING:
+                    kwargs["default"] = field.default
+                elif field.default_factory is not dataclasses.MISSING:  # type: ignore
+                    kwargs["default"] = field.default_factory()  # type: ignore
+                else:
+                    kwargs["required"] = True
+            parser.add_argument(field_name, **kwargs)
diff --git a/src/gt4sd/cli/hf_to_st_converter.py b/src/gt4sd/cli/hf_to_st_converter.py
new file mode 100755
index 000000000..cce6c4f21
--- /dev/null
+++ b/src/gt4sd/cli/hf_to_st_converter.py
@@ -0,0 +1,100 @@
+#!/usr/bin/env python
+
+"""Transformers pretrained model to SentenceTransformer model converter."""
+
+
+import json
+import logging
+import os
+import sys
+from dataclasses import dataclass, field
+from typing import cast
+
+from sentence_transformers import SentenceTransformer, __version__, models
+
+from .argument_parser import ArgumentParser, DataClassType
+
+
+@dataclass
+class TransformersToSentenceTransformersArguments:
+    """Transformers to Sentence Transformers converter arguments."""
+
+    __name__ = "hf_to_st_converter_args"
+
+    model_name_or_path: str = field(
+        metadata={"help": "HF model name or path."},
+    )
+    pooling: str = field(
+        metadata={
+            "help": "Comma separated pooling modes. Supported types: cls, max, mean, mean_sqrt."
+        },
+    )
+    output_path: str = field(
+        metadata={"help": "Path to the converted model."},
+    )
+
+
+def main() -> None:
+    """Convert HF pretrained model to SentenceTransformer.
+
+    Create a SentenceTransformer model having a given HF model as
+    word embedding model plus an optional pooling layer. We can
+    also concatenate multiple poolings together.
+
+    Parsing from the command line the following parameters:
+    - HF pretrained model to be used as word embedding model.
+    - the pooling mode (more than one can be provided as a comma-separated
+      list), the implemented options are "cls", "max", "mean" and "mean_sqrt".
+    - path to save the generated SentenceTransformer model.
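+
+    Example:
+        a hypothetical invocation, assuming the module is run as a script
+        (model name and output path are placeholders)::
+
+            python -m gt4sd.cli.hf_to_st_converter \
+                --model_name_or_path bert-base-uncased \
+                --pooling mean,cls \
+                --output_path /tmp/st_model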
+ """ + logging.basicConfig(stream=sys.stdout, level=logging.INFO) + arguments = ArgumentParser( + cast(DataClassType, TransformersToSentenceTransformersArguments) + ).parse_args_into_dataclasses(return_remaining_strings=True)[0] + + model_name_or_path = arguments.model_name_or_path + pooling = [ + polling_argument.strip() for polling_argument in arguments.pooling.split(",") + ] + output_path = arguments.output_path + + word_embedding_model = models.Transformer(model_name_or_path) + + pooling_mode_cls_token = False + pooling_mode_max_tokens = False + pooling_mode_mean_tokens = False + pooling_mode_mean_sqrt_len_tokens = False + + if "cls" in pooling: + pooling_mode_cls_token = True + if "max" in pooling: + pooling_mode_max_tokens = True + if "mean" in pooling: + pooling_mode_mean_tokens = True + if "mean_sqrt" in pooling: + pooling_mode_mean_sqrt_len_tokens = True + + pooling_model = models.Pooling( + word_embedding_model.get_word_embedding_dimension(), + pooling_mode=None, + pooling_mode_cls_token=pooling_mode_cls_token, + pooling_mode_max_tokens=pooling_mode_max_tokens, + pooling_mode_mean_tokens=pooling_mode_mean_tokens, + pooling_mode_mean_sqrt_len_tokens=pooling_mode_mean_sqrt_len_tokens, + ) + + model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) + + model.save(output_path) + + config_filepath = os.path.join(output_path, "config.json") + if os.path.exists(config_filepath): + with open(config_filepath) as fp: + config = json.load(fp) + config["__version__"] = __version__ + with open(config_filepath, "wt") as fp: + json.dump(config, fp, indent=2) + + +if __name__ == "__main__": + main() diff --git a/src/gt4sd/cli/load_arguments_from_dataclass.py b/src/gt4sd/cli/load_arguments_from_dataclass.py new file mode 100644 index 000000000..06bcf10ad --- /dev/null +++ b/src/gt4sd/cli/load_arguments_from_dataclass.py @@ -0,0 +1,87 @@ +"""Functions to facilitate conversion from dataclasses to training descriptions.""" + + +from dataclasses import _MISSING_TYPE, fields +from typing import Any, Dict, Optional, Type, Union + + +def find_type(input_type: Type) -> Optional[str]: + """Convert type class to string. + + Args: + input_type: Type to be converted to string. + + Returns: + String of the type or None if the given type is not supported. + """ + field_type = None + if input_type is str: + field_type = "string" + elif input_type is int: + field_type = "integer" + elif input_type is float: + field_type = "number" + elif input_type is bool: + field_type = "boolean" + + return field_type + + +def extract_fields_from_class( + dataclass: Type, +) -> Dict[str, Any]: + """Extract arguments from dataclass. + + Args: + dataclass: Dataclass to contains the arguments. + + Returns: + Dictionary of the existing arguments including their type, description and default value. 
+ """ + + # assign type and description + arg_fields = { + field.name: {"type": field.type, "description": field.metadata["help"]} + for field in fields(dataclass) + } + + # assign default values + for field in fields(dataclass): + + if not isinstance(field.default, _MISSING_TYPE): + + if field.default is None: + field.default = "none" + + arg_fields[field.name]["default"] = field.default + + # convert type to str + for field_name in arg_fields: + + field_type = find_type(arg_fields[field_name]["type"]) + + if field_type: + + arg_fields[field_name]["type"] = field_type + + elif ( + hasattr(arg_fields[field_name]["type"], "__origin__") + and arg_fields[field_name]["type"].__origin__ is Union + ): + + types = [ + find_type(type) for type in arg_fields[field_name]["type"].__args__ + ] + types = [type for type in types if type is not None] + + if len(types) == 1: + arg_fields[field_name]["type"] = types[0] + else: + raise ValueError(f"{arg_fields[field_name]['type']} not supported") + + else: + raise ValueError( + f" argument {field_name}: {arg_fields[field_name]['type']} not supported" + ) + + return arg_fields diff --git a/src/gt4sd/cli/pl_to_hf_converter.py b/src/gt4sd/cli/pl_to_hf_converter.py new file mode 100755 index 000000000..3c8350b62 --- /dev/null +++ b/src/gt4sd/cli/pl_to_hf_converter.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python + +"""Pytorch lightning checkpoint to HF transformers converter.""" + + +import logging +import sys +from dataclasses import dataclass, field +from typing import Optional, cast + +from transformers import Trainer, TrainingArguments + +from ..training_pipelines.pytorch_lightning.language_modeling.models import ( + LM_MODULE_FACTORY, +) +from .argument_parser import ArgumentParser, DataClassType + + +@dataclass +class PyTorchLightningToTransformersArguments: + """PyTorchLightning to Transformers converter arguments.""" + + __name__ = "pl_to_hf_converter_args" + + training_type: str = field( + metadata={ + "help": f"Training type of the converted model, supported types: {', '.join(LM_MODULE_FACTORY.keys())}." + }, + ) + model_name_or_path: str = field( + metadata={"help": "Model name or path."}, + ) + ckpt: str = field( + metadata={"help": "Path to checkpoint."}, + ) + output_path: str = field( + metadata={"help": "Path to the converted model."}, + ) + tokenizer_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "Tokenizer name or path. If not provided defaults to model_name_or_path." + }, + ) + + +def main() -> None: + """Convert pytorch lightning checkpoint to HF transformers model. + + Parsing from the command line the following parameters: + - training type of the given checkpoint. + - model name or path, a HF's model name. + - tokenizer name or path, a HF's tokenizer name. + - path of the checkpoint. + - path where the HF model will be saved. + + Raises: + ValueError: in case the provided training type is not supported. 
+ """ + logging.basicConfig(stream=sys.stdout, level=logging.INFO) + arguments = ArgumentParser( + cast(DataClassType, PyTorchLightningToTransformersArguments) + ).parse_args_into_dataclasses(return_remaining_strings=True)[0] + + training_type = arguments.training_type + model_name_or_path = arguments.model_name_or_path + tokenizer_name_or_path = arguments.tokenizer_name_or_path + if tokenizer_name_or_path is None: + tokenizer_name_or_path = model_name_or_path + ckpt = arguments.ckpt + output_path = arguments.output_path + + if training_type not in LM_MODULE_FACTORY: + ValueError( + f"LM training type {training_type} is not supported. Supported types: {', '.join(LM_MODULE_FACTORY.keys())}." + ) + model_module_class = LM_MODULE_FACTORY[training_type] + + model_module = model_module_class.load_from_checkpoint( + ckpt, + model_args={ + "model_name_or_path": model_name_or_path, + "tokenizer": tokenizer_name_or_path, + }, + ) + + trainer = Trainer( + model=model_module.model, + tokenizer=model_module.tokenizer, + args=TrainingArguments(output_dir=output_path), + ) + trainer.save_model() + + +if __name__ == "__main__": + main() diff --git a/src/gt4sd/cli/trainer.py b/src/gt4sd/cli/trainer.py new file mode 100755 index 000000000..89a555d15 --- /dev/null +++ b/src/gt4sd/cli/trainer.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python + +"""Run training pipelines for the GT4SD.""" + + +import logging +import sys +from dataclasses import dataclass, field +from typing import IO, Optional, Tuple, cast + +from ..training_pipelines import ( + TRAINING_PIPELINE_ARGUMENTS_MAPPING, + TRAINING_PIPELINE_MAPPING, +) +from ..training_pipelines.core import TrainingPipelineArguments +from .argument_parser import ArgumentParser, DataClass, DataClassType + +logger = logging.getLogger(__name__) + +SUPPORTED_TRAINING_PIPELINES = sorted( + list(set(TRAINING_PIPELINE_ARGUMENTS_MAPPING) & set(TRAINING_PIPELINE_MAPPING)) +) + + +@dataclass +class TrainerArguments: + """Trainer arguments.""" + + __name__ = "trainer_base_args" + + training_pipeline_name: str = field( + metadata={ + "help": f"Training type of the converted model, supported types: {', '.join(SUPPORTED_TRAINING_PIPELINES)}." + }, + ) + configuration_file: Optional[str] = field( + default=None, + metadata={ + "help": "Configuration file for the trainining. It can be used to completely by-pass pipeline specific arguments." + }, + ) + + +class TrainerArgumentParser(ArgumentParser): + """Argument parser using a custom help logic.""" + + def print_help(self, file: Optional[IO[str]] = None) -> None: + """Print help checking dynamically whether a specific pipeline is passed. + + Args: + file: an optional I/O stream. Defaults to None, a.k.a., stdout and stderr. 
+ """ + try: + help_args_set = {"-h", "--help"} + if ( + len(set(sys.argv).union(help_args_set)) < len(help_args_set) + 2 + ): # considering filename + super().print_help() + return + args = [arg for arg in sys.argv if arg not in help_args_set] + parsed_arguments = super().parse_args_into_dataclasses( + args=args, return_remaining_strings=True + ) + trainer_arguments = None + for arguments in parsed_arguments: + if arguments.__name__ == "trainer_base_args": + trainer_arguments = arguments + break + if trainer_arguments: + trainer_arguments.training_pipeline_name + training_pipeline_arguments = TRAINING_PIPELINE_ARGUMENTS_MAPPING.get( + trainer_arguments.training_pipeline_name, TrainingPipelineArguments + ) + parser = ArgumentParser( + tuple( + [TrainerArguments, *training_pipeline_arguments] # type:ignore + ) + ) + parser.print_help() + except Exception: + super().print_help() + + def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]: + """Overriding default .json parser. + + It by-passes all command line arguments and simply add the training pipeline. + + Args: + json_file: JSON file containing pipeline configuration parameters. + + Returns: + parsed arguments in a tuple of dataclasses. + """ + number_of_dataclass_types = len(self.dataclass_types) # type:ignore + self.dataclass_types = [ + dataclass_type + for dataclass_type in self.dataclass_types + if "gt4sd.cli.trainer.TrainerArguments" not in str(dataclass_type) + ] + try: + parsed_arguments = super().parse_json_file(json_file=json_file) + except Exception: + logger.exception( + f"error parsing configuration file: {json_file}, printing error and exiting" + ) + sys.exit(1) + if number_of_dataclass_types > len(self.dataclass_types): + self.dataclass_types.insert(0, cast(DataClassType, TrainerArguments)) + return parsed_arguments + + +def main() -> None: + """ + Run a training pipeline. + + Raises: + ValueError: in case the provided training pipeline provided is not supported. + """ + logging.basicConfig(stream=sys.stdout, level=logging.INFO) + + base_args = TrainerArgumentParser( + cast(DataClassType, TrainerArguments) + ).parse_args_into_dataclasses(return_remaining_strings=True)[0] + training_pipeline_name = base_args.training_pipeline_name + if training_pipeline_name not in set(SUPPORTED_TRAINING_PIPELINES): + ValueError( + f"Training pipeline {training_pipeline_name} is not supported. Supported types: {', '.join(SUPPORTED_TRAINING_PIPELINES)}." 
+ ) + arguments = TRAINING_PIPELINE_ARGUMENTS_MAPPING[training_pipeline_name] + parser = TrainerArgumentParser(tuple([TrainerArguments, *arguments])) # type:ignore + + configuration_filepath = base_args.configuration_file + if configuration_filepath: + args = parser.parse_json_file(json_file=configuration_filepath) + else: + args = parser.parse_args_into_dataclasses(return_remaining_strings=True) + config = { + arg.__name__: arg.__dict__ + for arg in args + if isinstance(arg, TrainingPipelineArguments) and isinstance(arg.__name__, str) + } + + pipeline = TRAINING_PIPELINE_MAPPING[training_pipeline_name] + pipeline().train(**config) + + +if __name__ == "__main__": + main() diff --git a/src/gt4sd/configuration.py b/src/gt4sd/configuration.py new file mode 100644 index 000000000..3c5b07987 --- /dev/null +++ b/src/gt4sd/configuration.py @@ -0,0 +1,123 @@ +"""Module configuration.""" + +import logging +import os +from functools import lru_cache +from typing import Optional, Set + +from pydantic import BaseSettings + +from .s3 import GT4SDS3Client, S3SyncError, sync_folder_with_s3 + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class GT4SDConfiguration(BaseSettings): + """GT4SDConfiguration settings from environment variables. + + Default configurations for gt4sd including a read-only COS for algorithms' artifacts. + """ + + gt4sd_local_cache_path: str = os.path.join(os.path.expanduser("~"), ".gt4sd") + gt4sd_local_cache_path_algorithms: str = "algorithms" + gt4sd_max_number_of_stuck_calls: int = 50 + gt4sd_max_number_of_samples: int = 1000000 + gt4sd_max_runtime: int = 86400 + gt4sd_s3_host: str = "s3.mil01.cloud-object-storage.appdomain.cloud" + gt4sd_s3_access_key: str = "a19f93a1c67949f1a31db38e58bcb7e8" + gt4sd_s3_secret_key: str = "5748375c761a4f09c30a68cd15e218e3b27ca3e2aebd7726" + gt4sd_s3_secure: bool = True + gt4sd_s3_bucket: str = "algorithms" + + class Config: + # immutable and in turn hashable, that is required for lru_cache + frozen = True + + @staticmethod + @lru_cache(maxsize=None) + def get_instance() -> "GT4SDConfiguration": + return GT4SDConfiguration() + + +gt4sd_configuration_instance = GT4SDConfiguration.get_instance() +logger.info( + f"using as local cache path: {gt4sd_configuration_instance.gt4sd_local_cache_path}" +) +try: + os.makedirs(gt4sd_configuration_instance.gt4sd_local_cache_path) +except FileExistsError: + logger.debug("local cache path already exists") + + +def sync_algorithm_with_s3(prefix: Optional[str] = None) -> str: + """Sync an algorithm in the local cache using environment variables. + + Args: + prefix: the relative path in the bucket (both + on S3 and locally) to match files to download. Defaults to None. + + Returns: + str: local path using the prefix. 
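+
+    Example:
+        syncing only the artifacts under a given prefix; the layout typically
+        mirrors ``<algorithm_type>/<algorithm_name>/<application>/<version>``
+        as used by the algorithm registry::
+
+            path = sync_algorithm_with_s3(
+                "conditional_generation/PaccMannRL/PaccMannRLProteinBasedGenerator"
+            )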
+ """ + folder_path = os.path.join( + gt4sd_configuration_instance.gt4sd_local_cache_path, + gt4sd_configuration_instance.gt4sd_local_cache_path_algorithms, + ) + try: + sync_folder_with_s3( + host=gt4sd_configuration_instance.gt4sd_s3_host, + access_key=gt4sd_configuration_instance.gt4sd_s3_access_key, + secret_key=gt4sd_configuration_instance.gt4sd_s3_secret_key, + bucket=gt4sd_configuration_instance.gt4sd_s3_bucket, + folder_path=folder_path, + prefix=prefix, + secure=gt4sd_configuration_instance.gt4sd_s3_secure, + ) + except S3SyncError: + logger.exception("error in syncing the cache with S3") + return os.path.join(folder_path, prefix) if prefix is not None else folder_path + + +def get_cached_algorithm_path(prefix: Optional[str] = None) -> str: + return ( + os.path.join( + gt4sd_configuration_instance.gt4sd_local_cache_path, + gt4sd_configuration_instance.gt4sd_local_cache_path_algorithms, + prefix, + ) + if prefix is not None + else os.path.join( + gt4sd_configuration_instance.gt4sd_local_cache_path, + gt4sd_configuration_instance.gt4sd_local_cache_path_algorithms, + ) + ) + + +def get_algorithm_subdirectories_with_s3(prefix: Optional[str] = None) -> Set[str]: + + try: + host = gt4sd_configuration_instance.gt4sd_s3_host + access_key = gt4sd_configuration_instance.gt4sd_s3_access_key + secret_key = gt4sd_configuration_instance.gt4sd_s3_secret_key + secure = gt4sd_configuration_instance.gt4sd_s3_secure + client = GT4SDS3Client( + host=host, access_key=access_key, secret_key=secret_key, secure=secure + ) + bucket = gt4sd_configuration_instance.gt4sd_s3_bucket + return client.list_directories(bucket=bucket, prefix=prefix) + except Exception: + logger.exception("generic syncing error") + raise S3SyncError( + "CacheSyncingError", + f"error in getting directories of prefix={prefix} with host={host} access_key={access_key} secret_key={secret_key} secure={secure} bucket={bucket}", + ) + + +def get_algorithm_subdirectories_in_cache(prefix: Optional[str] = None) -> Set[str]: + path = get_cached_algorithm_path(prefix=prefix) + try: + _, dirs, _ = next(iter(os.walk(path))) + return set(dirs) + except StopIteration: + return set() diff --git a/src/gt4sd/conftest.py b/src/gt4sd/conftest.py new file mode 100644 index 000000000..4102fc17e --- /dev/null +++ b/src/gt4sd/conftest.py @@ -0,0 +1,15 @@ +"""Make pytest fixtures available to multiple test directories.""" + +import pytest + + +@pytest.fixture +def mock_wrong_s3_env(monkeypatch): + """Changes an environment variable to break the s3 connection.""" + monkeypatch.setenv("GT4SD_S3_SECRET_KEY", "(╯°□°)╯︵ ┻━┻") + + +@pytest.fixture +def mock_missing_s3_env(monkeypatch): + """Deletes an environment variable to break the s3 connection.""" + monkeypatch.delenv("GT4SD_S3_SECRET_KEY") diff --git a/src/gt4sd/domains/__init__.py b/src/gt4sd/domains/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/gt4sd/domains/core.py b/src/gt4sd/domains/core.py new file mode 100644 index 000000000..0f0d75757 --- /dev/null +++ b/src/gt4sd/domains/core.py @@ -0,0 +1,7 @@ +"""Not so domain specific code.""" + +# NOTE these might change a lot +# - could be abstract base classes such that other implmentations +# can be simply registered, without explicit inheritance +# - could be simply NewType definitions if they do not implement functionality +# - organize/sort into domain specific files diff --git a/src/gt4sd/domains/materials/__init__.py b/src/gt4sd/domains/materials/__init__.py new file mode 100644 index 000000000..0054ea3c8 --- /dev/null +++ 
b/src/gt4sd/domains/materials/__init__.py
@@ -0,0 +1,58 @@
+"""Types, classes, validation, etc. for the material domain."""
+
+from typing import List, NewType, Tuple, Union
+
+import numpy as np
+import pandas as pd
+from rdkit import Chem
+
+from gt4sd.exceptions import InvalidItem
+
+# TODO setting to str directly requires no wrapping, so wrong strings could be passed
+Protein = str  # NewType('Protein', str)
+SMILES = str  # NewType('SMILES', str)
+SmallMolecule = SMILES
+Omics = Union[np.ndarray, pd.Series]
+PAG = SMILES
+Molecule = Union[SmallMolecule, Protein]
+Sequence = str
+Property = float
+
+
+def check_smiles(smiles: SMILES):
+    try:
+        # MolFromSmiles returns None when parsing or sanitization fails
+        if Chem.MolFromSmiles(smiles, sanitize=True) is None:
+            raise ValueError(f"parsing SMILES failed: {smiles}")
+    except Exception:
+        raise InvalidItem(title="invalid SMILES", detail="Validation as SMILES failed.")
+
+
+def validate_molecules(
+    smiles_list: List[SMILES],
+) -> Tuple[List[Chem.rdchem.Mol], List[int]]:
+    """Validate molecules.
+
+    Args:
+        smiles_list: list of SMILES representing molecules.
+
+    Returns:
+        a tuple containing RDKit molecules and valid indexes.
+    """
+    # generate molecules from SMILES
+    molecules = [
+        Chem.MolFromSmiles(a_smiles, sanitize=True) for a_smiles in smiles_list
+    ]
+    # valid ids
+    valid_ids = [
+        index for index, molecule in enumerate(molecules) if molecule is not None
+    ]
+    return molecules, valid_ids
+
+
+Bounds = Tuple[int, int]  # NewType('Bounds', Tuple[int, int])
+
+PhotoacidityCondition = NewType("PhotoacidityCondition", Bounds)
+# photoacidity_condition = PhotoacidityCondition(
+#     (0, 1)
+# )  # PhotoacidityCondition(Bounds((0, 1))) if Bound was a new type
+
+ConditionPAG = Union[PhotoacidityCondition]
diff --git a/src/gt4sd/domains/materials/protein_encoding.py b/src/gt4sd/domains/materials/protein_encoding.py
new file mode 100644
index 000000000..d9e901dad
--- /dev/null
+++ b/src/gt4sd/domains/materials/protein_encoding.py
@@ -0,0 +1,219 @@
+"""Data processing utilities."""
+
+import inspect
+from typing import Dict, Iterator, List, Optional, Sequence, Tuple, Union, cast
+
+import numpy as np
+import torch
+from tape.datasets import pad_sequences
+from tape.registry import registry
+from tape.tokenizers import TAPETokenizer
+from torch import nn
+
+
+class PrimarySequenceEncoder(nn.Module):
+    """Model like class to create tape embeddings/encodings.
+
+    This follows tape's implementation via `run_embed` closely, but removes
+    any seed/device/cuda handling (of model and batch). This can be done in
+    the training loop like for any other nn.Module.
+ + Example: + An example use with protein sequence dataset from `pytoda` (requires + mock/rdkit and pytoda>0.2) passing ids with the primary sequence:: + + import sys + from mock import Mock + sys.modules['rdkit'] = Mock() + sys.modules['rdkit.Chem'] = Mock() + from torch.utils.data import DataLoader + from pytoda.datasets.protein_sequence_dataset import protein_sequence_dataset + from pytoda.datasets.tests.test_protein_sequence_dataset import ( + FASTA_CONTENT_GENERIC, TestFileContent + ) + from pytoda.datasets.utils import keyed + + with TestFileContent(FASTA_CONTENT_GENERIC) as a_test_file: + sequence_dataset = keyed(protein_sequence_dataset( + a_test_file.filename, filetype='.fasta', backend='lazy' + )) + batch_size = 5 + dataloader = DataLoader(sequence_dataset, batch_size=batch_size) + + encoder = PrimarySequenceEncoder( + model_type='transformer', + from_pretrained='bert-base', + tokenizer='iupac', + log_level=logging.INFO, + ) + # sending encoder to cuda device should work, not tested + + loaded = next(iter(dataloader)) + print(loaded) + encoded, ids = encoder.forward(loaded) + print(ids) + print(encoded) + + However the forward call supports also not passing ids, but batch still + has to be wrapped as list (of length 1):: + + encoded, dummy_ids = PrimarySequenceEncoder().forward( + [ + ['MQNP', 'LLLLL'], # type: Sequence[str] + # sequence_ids may be missing here + ] + ) + """ + + def __init__( + self, + model_type: str = "transformer", + from_pretrained: Optional[str] = "bert-base", + model_config_file: Optional[str] = None, + # full_sequence_embed: bool = False, + tokenizer: str = "iupac", + ): + """Initialize the PrimarySequenceEncoder. + + Args: + model_type: Which type of model to create + (e.g. transformer, unirep, ...). Defaults to 'transformer'. + from_pretrained: either + a string with the `shortcut name` of a pre-trained model to + load from cache or download, e.g.: ``bert-base-uncased``, or + a path to a `directory` containing model weights saved using + :func:`tape.models.modeling_utils.ProteinConfig.save_pretrained`, + e.g.: ``./my_model_directory/``. + Defaults to 'bert-base'. + model_config_file: A json config file + that specifies hyperparameters. Defaults to None. + tokenizer: vocabulary name. Defaults to 'iupac'. 
+ + Note: + tapes default seed would be 42 (see `tape.utils.set_random_seeds`) + """ + super().__init__() + # padding during forward goes through cpu (numpy) + self.device_indicator = nn.Parameter(torch.empty(0), requires_grad=False) + # dummy sequence_ids, so they are optional + self.next_dummy_id = 0 + + task_spec = registry.get_task_spec("embed") # task = 'embed' + # from tape.datasets import EmbedDataset + self.model = registry.get_task_model( + model_type, task_spec.name, model_config_file, from_pretrained + ) + + # to filter out batch items that aren't used in this model + # see `from_collated_batch` and `tape.training.ForwardRunner` + forward_arg_keys = inspect.getfullargspec(self.model.forward).args + self._forward_arg_keys = forward_arg_keys[1:] # remove self argument + assert "input_ids" in self._forward_arg_keys + + self.tokenizer = TAPETokenizer(vocab=tokenizer) + self.full_sequence_embed = False + + self.eval() + + def train(self, mode: bool): # type:ignore + """Avoid any setting to train mode.""" + return super().train(False) + + def generate_tokenized( + self, batch: List[Sequence[str]] + ) -> Iterator[Tuple[str, np.ndarray, np.ndarray]]: + # batch is list of len 2 (typically tuples[str] of length `batch_size`) + for item, sequence_id in zip(*batch): + token_ids = self.tokenizer.encode(item) + input_mask: np.ndarray = np.ones_like(token_ids) + yield sequence_id, token_ids, input_mask + + @classmethod + def collate_fn( + cls, batch: List[Tuple[str, np.ndarray, np.ndarray]] + ) -> Dict[str, Union[List[str], torch.Tensor]]: + # from tape.datasets.EmbedDataset because there it's not a classmethod + ids, tokens, input_mask = zip(*batch) + ids_list: List[str] = list(ids) + tokens_tensor: torch.Tensor = torch.from_numpy(pad_sequences(tokens)) + input_mask_tensor: torch.Tensor = torch.from_numpy(pad_sequences(input_mask)) + # on cpu now, is unavoidable as tokenizer and mask are in numpy. + return { + "ids": ids_list, + "input_ids": tokens_tensor, + "input_mask": input_mask_tensor, + } # type: ignore + + def from_collated_batch( + self, batch: Dict[str, Union[List[str], torch.Tensor]] + ) -> Dict[str, torch.Tensor]: + # filter arguments + batch_tensors: Dict[str, torch.Tensor] = { + name: tensor # type:ignore + for name, tensor in batch.items() + if name in self._forward_arg_keys + } + device = self.device_indicator.device + if device.type == "cuda": + batch_tensors = { + name: tensor.cuda(device=device, non_blocking=True) + for name, tensor in batch_tensors.items() + } + return batch_tensors + + def forward( # type:ignore + self, batch: List[Sequence[str]] + ) -> Tuple[torch.Tensor, List[str]]: + # batch: List[(primary_sequences,), (sequence_ids,))] of length 2 + # keys can be passed on by pytoda via keyed(ds: Keydataset[str]) + if len(batch) == 1: + # no sequence_ids passed + dummy_ids = self.get_dummy_ids(length=len(batch[0])) + batch.append(dummy_ids) + elif len(batch) == 2: + pass + else: + raise ValueError( + "batch should be of length 1 or 2, containing `primary_sequences` " + " and optionally `sequence_ids`." 
+            )
+
+        with torch.no_grad():
+            # Iterator[(sequence_id, token_ids, input_mask)]
+            batch_loader_like = self.generate_tokenized(batch)
+            batch_dict_with_ids: Dict[
+                str, Union[List[str], torch.Tensor]
+            ] = self.collate_fn(list(batch_loader_like))
+            ids: List[str] = cast(List[str], batch_dict_with_ids["ids"])
+            batch_dict = self.from_collated_batch(batch_dict_with_ids)
+            # outputs = self.model(**batch_dict)
+            # pooled_embed = outputs[1]
+            sequence_embed = self.model(**batch_dict)[0]
+            sequence_lengths = batch_dict["input_mask"].sum(1)
+
+            # can variable length slicing be done on the batch?
+            if not self.full_sequence_embed:
+                sequences_out: torch.Tensor = sequence_embed.new_empty(
+                    # dimension of sequence length will be averaged out
+                    size=sequence_embed.shape[::2]
+                )
+            else:
+                raise NotImplementedError
+
+            for i, (seqembed, length) in enumerate(
+                zip(
+                    sequence_embed,
+                    sequence_lengths,
+                )
+            ):
+                seqembed = seqembed[: int(length)]
+                if not self.full_sequence_embed:
+                    seqembed = seqembed.mean(0)
+                sequences_out[i, ...] = seqembed
+
+        return sequences_out, ids
+
+    def get_dummy_ids(self, length: int) -> Tuple[str, ...]:
+        first = self.next_dummy_id
+        self.next_dummy_id += length  # before last
+        return tuple(map(str, range(first, self.next_dummy_id)))
diff --git a/src/gt4sd/domains/materials/scorer.py b/src/gt4sd/domains/materials/scorer.py
new file mode 100644
index 000000000..8ae098998
--- /dev/null
+++ b/src/gt4sd/domains/materials/scorer.py
@@ -0,0 +1,339 @@
+"""Implementation of scorers."""
+
+
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Type
+
+import numpy as np
+from rdkit import Chem
+from guacamol.common_scoring_functions import (
+    IsomerScoringFunction,
+    RdkitScoringFunction,
+    SMARTSScoringFunction,
+    TanimotoScoringFunction,
+)
+from guacamol.score_modifier import (
+    ClippedScoreModifier,
+    GaussianModifier,
+    MaxGaussianModifier,
+    MinGaussianModifier,
+)
+from guacamol.scoring_function import ScoringFunction
+from guacamol.utils.descriptors import (
+    bertz,
+    logP,
+    mol_weight,
+    num_aromatic_rings,
+    num_rings,
+    num_rotatable_bonds,
+    qed,
+    tpsa,
+)
+
+MODIFIERS: Dict[str, Callable[..., Any]] = {
+    "gaussian_modifier": GaussianModifier,
+    "min_gaussian_modifier": MinGaussianModifier,
+    "max_gaussian_modifier": MaxGaussianModifier,
+    "clipped_score_modifier": ClippedScoreModifier,
+}
+MODIFIERS_PARAMETERS: Dict[str, Dict[str, float]] = {
+    "gaussian_modifier": {"mu": 2, "sigma": 0.5},
+    "min_gaussian_modifier": {"mu": 0.75, "sigma": 0.1},
+    "max_gaussian_modifier": {"mu": 100, "sigma": 10},
+    "clipped_score_modifier": {"upper_x": 0.8},
+}
+DESCRIPTOR: Dict[str, Callable[..., Any]] = {
+    "num_rotatable_bonds": num_rotatable_bonds,
+    "num_aromatic_rings": num_aromatic_rings,
+    "log_p": logP,
+    "tpsa": tpsa,
+    "bertz": bertz,
+    "qed": qed,
+    "mol_weight": mol_weight,
+    "num_rings": num_rings,
+}
+
+
+def distance_to_score(distance: float, beta: float) -> float:
+    """Calculate an exponential score for a given distance.
+
+    Args:
+        distance: distance from the target value.
+        beta: decay factor of the exponential.
+
+    Returns:
+        an exponential score value for the given distance.
+    """
+    return np.exp(-beta * distance ** 2)
+
+
+class DistanceScorer(ScoringFunction):
+    def __init__(self, beta: float = 0.00000001) -> None:
+        """DistanceScorer turns distances from a target into scores via a partial copy of the distance_to_score function.
+
+        Args:
+            beta: a float value controlling the exponential decay of the score.
+        """
+        self.partial_distance_score = partial(distance_to_score, beta=beta)
+
+    def get_distance(self, smile_distance: float) -> float:
+        """Compute the exponential score for a given distance.
+
+        Args:
+            smile_distance: distance from the target value.
+
+        Returns:
+            an exponential score value for the given distance.
+        """
+        return self.partial_distance_score(smile_distance)
+
+
+class TargetValueScorer(DistanceScorer):
+    def __init__(self, target: float, scoring_function: Callable[[str], float]) -> None:
+        """Scoring function which is used to generate a score based on a target and a scoring function.
+
+        Args:
+            target: target score that will be used to get the distance to the score of the SMILES.
+            scoring_function: an instance of a scoring class.
+        """
+        super().__init__()
+        self.target = target
+        self.scoring_function = scoring_function
+
+    def score(self, smiles: str) -> float:
+        """Generates a score for a given SMILES.
+
+        Args:
+            smiles: SMILES.
+
+        Returns:
+            a score for the given SMILES.
+        """
+        return self.get_distance(self.scoring_function(smiles) - self.target)
+
+    def score_list(self, smiles_list: List[str]) -> List[float]:
+        """Generates a list of scores for a given SMILES list.
+
+        Args:
+            smiles_list: a list of SMILES.
+
+        Returns:
+            a list of scores.
+        """
+        return [
+            self.score(smiles)
+            for smiles in smiles_list
+            if Chem.MolFromSmiles(smiles) and smiles
+        ]
+
+
+class CombinedScorer:
+    def __init__(
+        self,
+        scorer_list: List[Type[Any]],
+        weights: Optional[List[float]] = None,
+    ) -> None:
+        """Scoring function which generates a combined score for a SMILES as per the given scoring functions.
+
+        Args:
+            scorer_list: a list of the scoring functions.
+            weights: an optional list of weights, one per scorer.
+        """
+        self.scorer_list = scorer_list
+        self.weights = self._normalize_weights(weights)
+
+    def _normalize_weights(self, weights=None) -> List[float]:
+        """Normalize the weights.
+
+        Args:
+            weights: a list of weights.
+
+        Returns:
+            a list of normalized weights summing to one.
+        """
+        weights = weights if weights else [1.0] * len(self.scorer_list)
+        offsetted_weights = [weight + min(weights) for weight in weights]
+        return [weight / float(sum(offsetted_weights)) for weight in offsetted_weights]
+
+    def score(self, smiles: str):
+        """Generates a score for a given SMILES.
+
+        Args:
+            smiles: SMILES.
+
+        Returns:
+            the weighted sum of all the scores generated by the given scoring functions.
+        """
+        return sum(
+            [
+                scorer.score(smiles) * weight
+                for scorer, weight in zip(self.scorer_list, self.weights)
+            ]
+        )
+
+    def score_list(self, smiles_list: List[str]) -> List[float]:
+        """Generates a list of scores for a given SMILES list.
+
+        Args:
+            smiles_list: a list of SMILES.
+
+        Returns:
+            a list of scores.
+        """
+        return [self.score(smiles) for smiles in smiles_list]
+
+
+class RDKitDescriptorScorer(TargetValueScorer):
+    def __init__(
+        self,
+        target: float,
+        modifier: str = "gaussian_modifier",
+        descriptor: str = "num_rotatable_bonds",
+    ) -> None:
+        """Scoring function wrapping RDKit descriptors.
+ + Args: + target: target score that will be used to get the distance to the score of the SMILES + modifier: score modifier + descriptor: molecular descriptors + """ + self.target = target + self.modifier = MODIFIERS[modifier](**MODIFIERS_PARAMETERS[modifier]) + self.descriptor = DESCRIPTOR[descriptor] + super().__init__(target=target, scoring_function=self.score) + + def score(self, smiles: str) -> float: + """Generates a score for a given SMILES + + Args: + smiles: SMILES. + + Returns: + A score for the given SMILES + """ + scoring_function = RdkitScoringFunction( + descriptor=self.descriptor, + score_modifier=self.modifier, + ) + return scoring_function.score_mol(Chem.MolFromSmiles(smiles)) + + +class TanimotoScorer(TargetValueScorer): + def __init__( + self, + target: float, + target_smile: str, + fp_type: str = "ECFP4", + modifier: str = "gaussian_modifier", + ) -> None: + """Scoring function that looks at the fingerprint similarity against a target molecule. + + Args: + target: target score that will be used to get the distance to the score of the SMILES + target_smile: target molecule to compare similarity + fp_type: fingerprint type + modifier: score modifier + """ + self.target = target + self.target_smile = target_smile + self.fp_type = fp_type + self.modifier = MODIFIERS[modifier](**MODIFIERS_PARAMETERS[modifier]) + super().__init__(target=target, scoring_function=self.score) + + def score(self, smiles: str) -> float: + """Generates a score for a given SMILES + + Args: + smiles: SMILES. + + Returns: + A score for the given SMILES + """ + scoring_function = TanimotoScoringFunction( + self.target_smile, + fp_type=self.fp_type, + score_modifier=self.modifier, + ) + return scoring_function.score_mol(Chem.MolFromSmiles(smiles)) + + +class IsomerScorer(TargetValueScorer): + def __init__(self, target: float, target_smile: str) -> None: + """Scoring function for closeness to a molecular formula. + + Args: + target: target score that will be used to get the distance to the score of the SMILES + target_smile: targeted SMILES to compare closeness with + """ + self.target = target + self.target_smile = target_smile + super().__init__(target=target, scoring_function=self.score) + + def score(self, smiles: str) -> float: + """Generates a score for a given SMILES + + Args: + smiles: SMILES. + + Returns: + A score for the given SMILES + """ + scoring_function = IsomerScoringFunction(self.target_smile) + return scoring_function.raw_score(smiles) + + +class SMARTSScorer(TargetValueScorer): + def __init__(self, target: float, target_smile: str, inverse: bool = True) -> None: + """Scoring function that looks at the fingerprint similarity against a target molecule. + + Args: + target: target score that will be used to get the distance to the score of the SMILES + target_smile: The SMARTS string to match + inverse: If True then SMARTS is desired else it is not desired in the molecules + """ + self.target = target + self.target_smile = target_smile + self.inverse = inverse + super().__init__(target=target, scoring_function=self.score) + + def score(self, smiles: str) -> float: + """Generates a score for a given SMILES + + Args: + smiles: SMILES. 
+
+        Returns:
+            a score for the given SMILES.
+        """
+        scoring_function = SMARTSScoringFunction(self.target_smile, self.inverse)
+        return scoring_function.score_mol(Chem.MolFromSmiles(smiles))
+
+
+class QEDScorer(TargetValueScorer):
+    def __init__(self, target: float) -> None:
+        """Scoring function that calculates the weighted sum of ADS mapped properties using the QED module of RDKit.
+
+        Args:
+            target: target score that will be used to get the distance to the score of the SMILES.
+        """
+        self.target = target
+        super().__init__(target=target, scoring_function=self.score)
+
+    def score(self, smiles: str) -> float:
+        """Generates a score for a given SMILES.
+
+        Args:
+            smiles: SMILES.
+
+        Returns:
+            a score for the given SMILES.
+        """
+        return Chem.QED.qed(Chem.MolFromSmiles(smiles))
+
+
+SCORING_FUNCTIONS = {
+    "rdkit_scorer": RDKitDescriptorScorer,
+    "tanimoto_scorer": TanimotoScorer,
+    "isomer_scorer": IsomerScorer,
+    "smarts_scorer": SMARTSScorer,
+    "qed_scorer": QEDScorer,
+}
diff --git a/src/gt4sd/exceptions.py b/src/gt4sd/exceptions.py
new file mode 100644
index 000000000..70dad77ea
--- /dev/null
+++ b/src/gt4sd/exceptions.py
@@ -0,0 +1,81 @@
+"""Custom exception definitions."""
+
+
+class S3SyncError(RuntimeError):
+    """Error in syncing the cache with S3."""
+
+    def __init__(self, title: str, detail: str) -> None:
+        """Initialize S3SyncError.
+
+        Args:
+            title: title of the error.
+            detail: description of the error.
+        """
+        self.type = "S3SyncError"
+        self.title = title
+        self.detail = detail
+        super().__init__(detail)
+
+
+class InvalidItem(ValueError):
+    """Error in validating an item."""
+
+    def __init__(self, title: str, detail: str) -> None:
+        """Initialize InvalidItem.
+
+        Args:
+            title: title of the error.
+            detail: description of the error.
+        """
+        self.type = "InvalidItem"
+        self.title = title
+        self.detail = detail
+        super().__init__(detail)
+
+
+class InvalidAlgorithmConfiguration(ValueError):
+    """Error in validating an algorithm configuration."""
+
+    def __init__(self, title: str, detail: str) -> None:
+        """Initialize InvalidAlgorithmConfiguration.
+
+        Args:
+            title: title of the error.
+            detail: description of the error.
+        """
+        self.type = "InvalidAlgorithmConfiguration"
+        self.title = title
+        self.detail = detail
+        super().__init__(detail)
+
+
+class DuplicateApplicationRegistration(ValueError):
+    """Error when the identifier for a registration is not unique."""
+
+    def __init__(self, title: str, detail: str) -> None:
+        """Initialize DuplicateApplicationRegistration.
+
+        Args:
+            title: title of the error.
+            detail: description of the error.
+        """
+        self.type = "DuplicateApplicationRegistration"
+        self.title = title
+        self.detail = detail
+        super().__init__(detail)
+
+
+class SamplingError(TimeoutError):
+    """Error when inference takes too long."""
+
+    def __init__(self, title: str, detail: str) -> None:
+        """Initialize SamplingError.
+
+        Args:
+            title: title of the error.
+            detail: description of the error.
+ """ + self.type = "SamplingError" + self.title = title + self.detail = detail + super().__init__(detail) diff --git a/src/gt4sd/extras/__init__.py b/src/gt4sd/extras/__init__.py new file mode 100644 index 000000000..fddd6bbc9 --- /dev/null +++ b/src/gt4sd/extras/__init__.py @@ -0,0 +1,11 @@ +"""Extras handling.""" +# extras requirements +EXTRAS_ENABLED: bool +try: + import AMD_Analytics # noqa: F401 + import cog # noqa: F401 + import pag # noqa: F401 + + EXTRAS_ENABLED = True +except ImportError: + EXTRAS_ENABLED = False diff --git a/src/gt4sd/frameworks/__init__.py b/src/gt4sd/frameworks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/gt4sd/frameworks/enzeptional/__init__.py b/src/gt4sd/frameworks/enzeptional/__init__.py new file mode 100644 index 000000000..b823809bb --- /dev/null +++ b/src/gt4sd/frameworks/enzeptional/__init__.py @@ -0,0 +1,6 @@ +"""enzeptional - ENZymE OPTImizatiON for biocatALysis. + +Module for enzyme optimization. +""" + +from .optimization import EnzymeOptimizer # noqa: F401 diff --git a/src/gt4sd/frameworks/enzeptional/optimization.py b/src/gt4sd/frameworks/enzeptional/optimization.py new file mode 100644 index 000000000..10da183fe --- /dev/null +++ b/src/gt4sd/frameworks/enzeptional/optimization.py @@ -0,0 +1,422 @@ +"""Enzyme optimization.""" + +import json +import random +import time +from collections import OrderedDict +from typing import Any, Dict, List, MutableMapping, Optional, Sequence, Tuple, Union + +import numpy as np +from joblib import load +from loguru import logger + +from .processing import ( + HuggingFaceTransformerEmbedding, + StringEmbedding, + TAPEEmbedding, + reconstruct_sequence_with_mutation_range, + sanitize_intervals, +) + +#: transition matrix representation +TransitionMatrix = MutableMapping[str, MutableMapping[str, float]] +#: transition matrix configuration +TransitionConfiguration = MutableMapping[ + str, Union[MutableMapping[str, float], Sequence[str]] +] + +#: supported features +SUPPORTED_FEATURE_SET = set(["substrate", "product", "sequence"]) + +#: IUPAC code mapping +IUPAC_CODES = OrderedDict( + [ + ("Ala", "A"), + ("Asx", "B"), # Aspartate or Asparagine + ("Cys", "C"), + ("Asp", "D"), + ("Glu", "E"), + ("Phe", "F"), + ("Gly", "G"), + ("His", "H"), + ("Ile", "I"), + ("Lys", "K"), + ("Leu", "L"), + ("Met", "M"), + ("Asn", "N"), + ("Pyl", "O"), # Pyrrolysin + ("Pro", "P"), + ("Gln", "Q"), + ("Arg", "R"), + ("Ser", "S"), + ("Thr", "T"), + ("Sec", "U"), # Selenocysteine + ("Val", "V"), + ("Trp", "W"), + ("Xaa", "X"), # Any AA + ("Tyr", "Y"), + ("Glx", "Z"), # Glutamate or Glutamine + ] +) +#: IUPAC character set +IUPAC_CHARACTER_SET = set(IUPAC_CODES.values()) +#: IUPAC uniform mutation mapping, we exclude 'X' from the mapping values because it denotes a generic AA +IUPAC_MUTATION_MAPPING: TransitionConfiguration = { + iupac_character: sorted(list(IUPAC_CHARACTER_SET - {iupac_character, "X"})) + for iupac_character in IUPAC_CHARACTER_SET +} + + +class Mutations: + """Mutations definition class.""" + + def __init__(self, transition_configuration: TransitionConfiguration) -> None: + """Generate the mutation given the configuration for the transitions. + + Args: + transition_configuration: transition configuration. 
+ """ + self.transition_matrix = Mutations.transition_configuration_to_matrix( + transition_configuration + ) + + @staticmethod + def transition_configuration_to_matrix( + transition_configuration: TransitionConfiguration, + ) -> TransitionMatrix: + """Transform a configuration into a valid transition matrix. + + Args: + transition_configuration: transition configuration. + + Returns: + a transition matrix. + """ + transition_matrix: TransitionMatrix = dict() + for transition_source, transition_targets in transition_configuration.items(): + if isinstance(transition_targets, dict): + total = float(sum(transition_targets.values())) + transition_matrix[transition_source] = { + transition_target: transtion_element / total + for transition_target, transtion_element in transition_targets.items() + } + else: + transition_matrix[transition_source] = { + transition_target: 1 / len(transition_targets) + for transition_target in transition_targets + } + return transition_matrix + + @staticmethod + def from_json(filepath: str) -> "Mutations": + """Parse the mutation from a JSON containing the transition configuration. + + Returns: + the mutations object. + """ + with open(filepath) as fp: + return Mutations(json.load(fp)) + + def mutate(self, source: str) -> str: + """Mutate a source string. + + Args: + source: source string. + + Returns: + the mutated target. + """ + targets, probabilities = zip(*self.transition_matrix[source].items()) + return np.random.choice(targets, size=1, p=probabilities).item() + + +class AASequence: + def __init__( + self, sequence: str, mutations: Mutations = Mutations(IUPAC_MUTATION_MAPPING) + ) -> None: + """Initialize an AA sequence representation. + + Args: + sequence: AA sequence. + mutations: mutations definition. Defaults to uniform sampling of IUPAC AAs. + """ + self.sequence = sequence + self.sequence_length = len(sequence) + self.mutations = mutations + + def mutate(self, maximum_number_of_mutations: int) -> str: + """Mutate the sequence in multiple positions. + + Args: + maximum_number_of_mutations: maximum number of mutations. + + Returns: + the mutated sequence. + """ + if maximum_number_of_mutations > self.sequence_length: + logger.warning( + f"resetting maximum number of mutations ({maximum_number_of_mutations}), since it is higher than sequence length: {self.sequence_length}" + ) + maximum_number_of_mutations = self.sequence_length + if maximum_number_of_mutations < 1: + logger.warning( + f"maximum number of mutations can't be lower than 1 ({maximum_number_of_mutations}), resetting to 1" + ) + maximum_number_of_mutations = 1 + number_of_mutations = random.randint(1, maximum_number_of_mutations) + positions = sorted( + random.sample(range(self.sequence_length), number_of_mutations) + ) + mutated_sequence = "" + start_position = -1 + for position in positions: + mutated_sequence += self.sequence[(start_position + 1) : position] + mutated_sequence += self.mutations.mutate(self.sequence[position]) + start_position = position + mutated_sequence += self.sequence[(start_position + 1) :] + return mutated_sequence + + +class EnzymeOptimizer: + """Optimize an enzyme to catalyze a reaction from substrate to product.""" + + def __init__( + self, + scorer_filepath: str, + substrate: str, + product: str, + sequence: str, + protein_embedding: StringEmbedding = TAPEEmbedding(), + molecule_embedding: StringEmbedding = HuggingFaceTransformerEmbedding(), + ordering: List[str] = ["substrate", "product", "sequence"], + ) -> None: + """Initialize the enzyme designer. 
+
+
+class EnzymeOptimizer:
+    """Optimize an enzyme to catalyze a reaction from substrate to product."""
+
+    def __init__(
+        self,
+        scorer_filepath: str,
+        substrate: str,
+        product: str,
+        sequence: str,
+        protein_embedding: StringEmbedding = TAPEEmbedding(),
+        molecule_embedding: StringEmbedding = HuggingFaceTransformerEmbedding(),
+        ordering: List[str] = ["substrate", "product", "sequence"],
+    ) -> None:
+        """Initialize the enzyme designer.
+
+        Args:
+            scorer_filepath: pickled scorer filepath.
+            substrate: substrate SMILES.
+            product: product SMILES.
+            sequence: AA sequence representing the enzyme to optimize.
+            protein_embedding: protein embedding class. Defaults to TAPE bert-base.
+            molecule_embedding: molecule embedding class. Defaults to ChemBERTa version 1.
+            ordering: ordering of the features for the scorer. Defaults to ["substrate", "product", "sequence"].
+
+        Raises:
+            ValueError: ordering provided is not feasible.
+
+        Example:
+            An example optimizing a specific reaction::
+
+                filepath = "/path/to/model/scoring_model.pkl"
+                substrate = "NC1=CC=C(N)C=C1"
+                product = "CNC1=CC=C(NC(=O)C2=CC=C(C=C2)C(C)=O)C=C1"
+                sequence = (
+                    "MSIQIKQSTMVRPAEETPNKSLWLSNIDMILRTPYSHTGAVLIYKQPDNNEDNIHPSSSMYFDANILIEALSKA"
+                    "LVPFYPMAGRLKINGDRYEIDCNAEGALFVEAESSHVLEDFGDFRPNDELHRVMVPTCDYSKGISSFPLLMVQLT"
+                    "RFRCGGVSIGFAQHHHVCDGMAHFEFNNSWARIAKGLLPALEPVHDRYLHLRPRNPPQIKYSHSQFEPFVPSLPN"
+                    "ELLDGKTNKSQTLFILSREQINTLKQKLDLSNNTTRLSTYEVVAAHVWRSVSKARGLSDHEEIKLIMPVDGRSRIN"
+                    "NPSLPKGYCGNVVFLAVCTATVGDLSCNPLTDTAGKVQEALKGLDDDYLRSAIDHTESKPGLPVPYMGSPEKTLYPN"
+                    "VLVNSWGRIPYQAMDFGWGSPTFFGISNIFYDGQCFLIPSRDGDGSMTLAINLFSSHLSRFKKYFYDF"
+                )
+                # instantiate the designer
+                designer = EnzymeOptimizer(
+                    scorer_filepath=filepath, substrate=substrate, product=product, sequence=sequence
+                )
+
+                # with this sequence length every step takes ~5s
+                # optimize between positions 150 and 405 allowing for a maximum of 5 mutations
+                results = designer.optimize(
+                    number_of_mutations=5, number_of_steps=10, number_of_samples_per_step=8,
+                    intervals=[(150, 405)]
+                )
+                best_score = results[0]["score"]
+                best_sequence = results[0]["sequence"]
+        """
+        if len(set(ordering).intersection(SUPPORTED_FEATURE_SET)) < 3:
+            raise ValueError(
+                f"ordering={ordering} should contain only the three admissible values: {sorted(list(SUPPORTED_FEATURE_SET))}"
+            )
+        else:
+            self._ordering = ordering
+        self.scorer_filepath = scorer_filepath
+        self.scorer = load(scorer_filepath)
+        self.substrate = substrate
+        self.product = product
+        self.protein_embedding = protein_embedding
+        self.molecule_embedding = molecule_embedding
+        self.embedded_vectors = {
+            "substrate": self.molecule_embedding.embed_one(self.substrate),
+            "product": self.molecule_embedding.embed_one(self.product),
+        }
+        self.sequence = sequence
+        self.sequence_length = len(sequence)
+
+    def score_sequence(self, sequence: str) -> float:
+        """Score a given sequence.
+
+        Args:
+            sequence: a sequence to score.
+
+        Returns:
+            score for the sequence.
+        """
+        embedded_vectors = {"sequence": self.protein_embedding.embed_one(sequence)}
+        embedded_vectors.update(self.embedded_vectors)
+        feature_vector = np.concatenate(
+            [embedded_vectors[feature] for feature in self._ordering], axis=1
+        )
+        return self.scorer.predict_proba(feature_vector)[0][1]
+
+    def score_sequences(self, sequences: List[str]) -> List[Dict[str, Any]]:
+        """Score a given sequence list.
+
+        Args:
+            sequences: a list of sequences to score.
+
+        Returns:
+            a list of dictionaries of sequences and related scores.
+ """ + number_of_sequences = len(sequences) + embedded_matrices = { + "substrate": np.repeat( + self.embedded_vectors["substrate"], number_of_sequences, axis=0 + ), + "product": np.repeat( + self.embedded_vectors["product"], number_of_sequences, axis=0 + ), + } + embedded_matrices["sequence"] = self.protein_embedding(sequences) + feature_vector = np.concatenate( + [embedded_matrices[feature] for feature in self._ordering], axis=1 + ) + return [ + {"sequence": sequence, "score": score} + for sequence, score in zip( + sequences, self.scorer.predict_proba(feature_vector)[:, 1] + ) + ] + + def optimize( + self, + number_of_mutations: int, + intervals: Optional[List[Tuple[int, int]]] = None, + number_of_steps: int = 10, + number_of_samples_per_step: int = 32, + number_of_sequences: Optional[int] = None, + seed: int = 42, + time_budget: Optional[int] = None, + mutations: Mutations = Mutations(IUPAC_MUTATION_MAPPING), + ) -> List[Dict[str, Any]]: + """Optimize the enzyme given a number of mutations and a range. + + If the range limits are not provided the full sequence is optimized, this might be inefficient. + The sampling is performing by exploring mutations with a slightly smart random sampling. + + Args: + number_of_mutations: number of allowed mutations. + intervals: list of ranges in the sequence, zero-based. Defaults to None, a.k.a. use optimize the full sequence. + number_of_steps: number of optimization steps. Defaults to 100. + number_of_samples_per_step: number of samples sequences per optimization step. Defaults to 32. + number_of_sequences: number of optimal seuqence returned. Defaults to None, a.k.a, returns all. + seed: seed for random number generation. Defaults to 42. + time_budget: maximum allowed runtime in seconds. Defaults to None, a.k.a., no time limit, running for number_of_steps steps. + mutations: mutations definition. Defaults to uniform sampling of IUPAC AAs. + + Raises: + ValueError: in case an invalid range is provided. + + Returns: + a list of dictionaries containing a candidate optimal sequence and the related score. Sorted from best to worst. + Note that, when no limit on the returned number of sequences is set, the worst sequence is the original unmutated sequence. + If the optimization fails, only the original sequence is returned. + """ + random.seed(seed) + + # check if interval is None. 
+    def optimize(
+        self,
+        number_of_mutations: int,
+        intervals: Optional[List[Tuple[int, int]]] = None,
+        number_of_steps: int = 10,
+        number_of_samples_per_step: int = 32,
+        number_of_sequences: Optional[int] = None,
+        seed: int = 42,
+        time_budget: Optional[int] = None,
+        mutations: Mutations = Mutations(IUPAC_MUTATION_MAPPING),
+    ) -> List[Dict[str, Any]]:
+        """Optimize the enzyme given a number of mutations and a range.
+
+        If the range limits are not provided, the full sequence is optimized; this might be inefficient.
+        The sampling is performed by exploring mutations with a slightly smart random sampling.
+
+        Args:
+            number_of_mutations: number of allowed mutations.
+            intervals: list of ranges in the sequence, zero-based. Defaults to None, a.k.a., optimize the full sequence.
+            number_of_steps: number of optimization steps. Defaults to 10.
+            number_of_samples_per_step: number of sampled sequences per optimization step. Defaults to 32.
+            number_of_sequences: number of optimal sequences returned. Defaults to None, a.k.a., return all.
+            seed: seed for random number generation. Defaults to 42.
+            time_budget: maximum allowed runtime in seconds. Defaults to None, a.k.a., no time limit, running for number_of_steps steps.
+            mutations: mutations definition. Defaults to uniform sampling of IUPAC AAs.
+
+        Raises:
+            ValueError: in case an invalid range is provided.
+
+        Returns:
+            a list of dictionaries containing a candidate optimal sequence and the related score. Sorted from best to worst.
+            Note that, when no limit on the returned number of sequences is set, the worst sequence is the original unmutated sequence.
+            If the optimization fails, only the original sequence is returned.
+        """
+        random.seed(seed)
+
+        # check if intervals is None; in that case, use the whole sequence as the interval
+        if intervals is None:
+            intervals = [(0, self.sequence_length)]
+        else:
+            intervals = sanitize_intervals(
+                intervals
+            )  # here we merge and sort the intervals
+
+        # check that the intervals are in the range of the sequence length
+        if intervals[-1][1] > self.sequence_length:
+            raise ValueError(
+                "check provided intervals, at least an interval is larger than the sequence length"
+            )
+
+        # create a sequence based on the intervals
+        sequence_from_intervals = "".join(
+            [self.sequence[start:end] for start, end in intervals]
+        )
+
+        # mutate the sequence from intervals
+        aa_sequence_range = AASequence(sequence_from_intervals, mutations=mutations)
+        maximum_number_of_mutations = number_of_mutations
+
+        logger.info(
+            f"maximum number of mutations for the intervals: {maximum_number_of_mutations}"
+        )
+        scored_original_sequence = {
+            "score": self.score_sequence(self.sequence),
+            "sequence": self.sequence,
+        }
+        original_sequence_score = scored_original_sequence["score"]
+        logger.info(f"original sequence score: {original_sequence_score}")
+        results: List[Dict[str, Any]] = [scored_original_sequence]
+        # slightly smart random sampling
+        visited_sequences = set()
+        start_time = time.time()
+        for step in range(number_of_steps):
+            logger.info(f"optimization step={step + 1}")
+            mutated_sequences = []
+
+            for _ in range(number_of_samples_per_step):
+                mutated_sequence_range = aa_sequence_range.mutate(
+                    maximum_number_of_mutations=maximum_number_of_mutations
+                )
+
+                mutated_sequence = reconstruct_sequence_with_mutation_range(
+                    sequence=self.sequence,
+                    mutated_sequence_range=mutated_sequence_range,
+                    intervals=intervals,
+                )
+
+                # make sure we do not revisit
+                if mutated_sequence not in visited_sequences:
+                    visited_sequences.add(mutated_sequence)
+                    mutated_sequences.append(mutated_sequence)
+
+            # add only mutated sequences that are more optimal than the original
+            results += [
+                scored_sequence
+                for scored_sequence in self.score_sequences(mutated_sequences)
+                if scored_sequence["score"] > original_sequence_score
+            ]
+            logger.info(
+                f"best score at step={step + 1}: {max([scored_sequence['score'] for scored_sequence in results])}"
+            )
+            elapsed_time = int(time.time() - start_time)
+            if time_budget is not None:
+                if elapsed_time > time_budget:
+                    logger.warning(
+                        f"used all the given time budget of {time_budget}s, exiting optimization loop"
+                    )
+                    break
+        logger.info(
+            f"optimization completed visiting {len(visited_sequences)} mutated sequences"
+        )
+        sorted_results = sorted(
+            results, key=lambda result: result["score"], reverse=True
+        )[:number_of_sequences]
+        if len(sorted_results) < 2:
+            logger.error(
+                "optimization failed, could not find a mutated sequence more optimal than the original"
+            )
+        else:
+            logger.info(
+                f"found {len(sorted_results) - 1} optimal mutated sequences, best score: {sorted_results[0]['score']}"
+            )
+        return sorted_results
diff --git a/src/gt4sd/frameworks/enzeptional/processing.py b/src/gt4sd/frameworks/enzeptional/processing.py
new file mode 100644
index 000000000..c64e8a59a
--- /dev/null
+++ b/src/gt4sd/frameworks/enzeptional/processing.py
@@ -0,0 +1,234 @@
+"""enzeptional - data processing utilities."""
+
+from abc import ABC
+from typing import Generic, List, Optional, Tuple, TypeVar, Union
+
+import numpy as np
+import torch
+from tape.datasets import pad_sequences
+from tape.registry import registry
+from tape.tokenizers import TAPETokenizer
+from transformers import AutoModelWithLMHead, AutoTokenizer
+
+from ..torch import device_claim
+
+T = TypeVar("T")  # used for sample embedding
+
+
+class Embedding(ABC, Generic[T]):
+    """Abstract embedding class."""
+
+    def embed_one(self, sample: T) -> np.ndarray:
+        """Embed one sample.
+
+        Args:
+            sample: sample representation.
+
+        Returns:
+            embedding vector for the sample.
+        """
+        return self.__call__([sample])
+
+    def __call__(self, samples: List[T]) -> np.ndarray:
+        """Embed multiple samples.
+
+        Args:
+            samples: a list of sample representations.
+
+        Returns:
+            embedding vectors for the samples.
+        """
+        raise NotImplementedError
+
+
+StringEmbedding = Embedding[str]
+
+
+class TAPEEmbedding(StringEmbedding):
+    """Embed AA sequences using TAPE."""
+
+    def __init__(
+        self,
+        model_type: str = "transformer",
+        model_dir: str = "bert-base",
+        aa_vocabulary: str = "iupac",
+        device: Optional[Union[torch.device, str]] = None,
+    ) -> None:
+        """Initialize the TAPE embedding class.
+
+        Args:
+            model_type: TAPE model type. Defaults to "transformer".
+            model_dir: model directory. Defaults to "bert-base".
+            aa_vocabulary: type of vocabulary. Defaults to "iupac".
+            device: device where the inference is running, either as a dedicated class or a string.
+                If not provided, it is inferred.
+        """
+        # get device
+        self.device = device_claim(device)
+        # task and model definition
+        self.task_specification = registry.get_task_spec("embed")
+        self.model = registry.get_task_model(
+            model_type, self.task_specification.name, load_dir=model_dir
+        )
+        self.model = self.model.to(self.device)
+        self.model.eval()
+        self.tokenizer = TAPETokenizer(vocab=aa_vocabulary)
+
+    def _encode_and_mask(self, sequence: str) -> Tuple[np.ndarray, np.ndarray]:
+        """Encode and mask a sequence.
+
+        Args:
+            sequence: AA sequence.
+
+        Returns:
+            a tuple containing the token ids and the mask.
+        """
+        token_ids = self.tokenizer.encode(sequence)
+        return token_ids, np.ones_like(token_ids)
+
+    def __call__(self, samples: List[str]) -> np.ndarray:
+        """Embed multiple protein sequences using TAPE.
+
+        Args:
+            samples: a list of protein sequences.
+
+        Returns:
+            a numpy array containing the embedding vectors.
+        """
+        # prepare input
+        token_ids, masks = zip(
+            *[self._encode_and_mask(sequence) for sequence in samples]
+        )
+        input_data = {
+            "input_ids": torch.from_numpy(pad_sequences(token_ids)).to(self.device),
+            "input_mask": torch.from_numpy(pad_sequences(masks)).to(self.device),
+        }
+        sequence_lengths = input_data["input_mask"].sum(1)
+        sequence_embeddings = self.model(**input_data)[0].cpu().detach().numpy()
+        # get average embedding
+        return np.array(
+            [
+                sequence_embedding[:sequence_length].mean(0)
+                for sequence_embedding, sequence_length in zip(  # type:ignore
+                    sequence_embeddings, sequence_lengths
+                )
+            ]
+        )
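+
+# Note: a minimal usage sketch for TAPEEmbedding, assuming the TAPE
+# "bert-base" weights are available locally or downloadable.
+#
+#     embedder = TAPEEmbedding()
+#     single = embedder.embed_one("MSIQLKQST")         # shape: (1, hidden_size)
+#     batch = embedder(["MSIQLKQST", "MTEYKLVVVG"])    # shape: (2, hidden_size)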
+ """ + # get device + self.device = device_claim(device) + # tokenizer and model definition + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.model = AutoModelWithLMHead.from_pretrained(tokenizer_name) + self.model = self.model.to(self.device) + self.model.eval() + + def __call__(self, samples: List[str]) -> np.ndarray: + """Embed multiple protein sequences using TAPE. + + Args: + samples: a list of strings representing molecules. + + Returns: + a numpy array containing the embedding vectors. + """ + # get the CLS token representation from each SMILES. + return ( + self.model( + **{ + key: tensor.to(self.device) + for key, tensor in self.tokenizer( + samples, return_tensors="pt", padding=True + ).items() + } + )[0][:, 0, :] + .detach() + .numpy() + ) + + +def mutate_sequence_with_variant(sequence: str, variant: str) -> str: + """Given an AA sequence and a variant returns the mutated AA sequence. + + Args: + sequence: an AA sequence. + variant: a variant annotation. + + Returns: + the mutated sequence. + """ + edits = [ + (int(variant_string[1:-1]), variant[0], variant_string[-1]) + for variant_string in map(str.strip, variant.split("/")) + ] + mutated_sequence = list(sequence) + for index, _, aa_to in edits: + mutated_sequence[index] = aa_to + return "".join(mutated_sequence) + + +def sanitize_intervals(intervals: List[Tuple[int, int]]) -> List[Tuple[int, int]]: + """Sanitize intervals merging overlapping ones and sorting them. + + Args: + intervals: intervals to sanitize. + + Returns: + sorted and non overlapping intervals. + """ + sorted_intervals = sorted(intervals, key=lambda interval: interval[0]) + merged_intervals = [sorted_intervals[0]] + for current in sorted_intervals: + previous_end = merged_intervals[-1][1] + if current[0] <= previous_end: + previous_end = max(previous_end, current[1]) + else: + merged_intervals.append(current) + return merged_intervals + + +def reconstruct_sequence_with_mutation_range( + sequence: str, mutated_sequence_range: str, intervals: List[Tuple[int, int]] +): + """Reconstruct a sequence replacing in given positions sub-sequences from a mutated range. + + Args: + sequence: original sequence. + mutated_sequence_range: mutated sequence range. + intervals: sorted and non overlapping intervals. + + Returns: + reconstructed sequence. + """ + # create the mutated sequence, considering sorted intervals + mutated_range_offset = 0 # offset with respect to the mutated_sequence_range + mutated_sequence_offset = 0 # offset with respect to the full mutated sequence. 
+
+
+def sanitize_intervals(intervals: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
+    """Sanitize intervals merging overlapping ones and sorting them.
+
+    Args:
+        intervals: intervals to sanitize.
+
+    Returns:
+        sorted and non overlapping intervals.
+    """
+    sorted_intervals = sorted(intervals, key=lambda interval: interval[0])
+    merged_intervals = [sorted_intervals[0]]
+    for current in sorted_intervals[1:]:
+        previous_start, previous_end = merged_intervals[-1]
+        if current[0] <= previous_end:
+            # extend the previous interval to cover the current one
+            merged_intervals[-1] = (previous_start, max(previous_end, current[1]))
+        else:
+            merged_intervals.append(current)
+    return merged_intervals
+
+
+def reconstruct_sequence_with_mutation_range(
+    sequence: str, mutated_sequence_range: str, intervals: List[Tuple[int, int]]
+) -> str:
+    """Reconstruct a sequence replacing in given positions sub-sequences from a mutated range.
+
+    Args:
+        sequence: original sequence.
+        mutated_sequence_range: mutated sequence range.
+        intervals: sorted and non overlapping intervals.
+
+    Returns:
+        reconstructed sequence.
+    """
+    # create the mutated sequence, considering sorted intervals
+    mutated_range_offset = 0  # offset with respect to the mutated_sequence_range
+    mutated_sequence_offset = 0  # offset with respect to the full mutated sequence
+    mutated_sequence = ""
+    for start, end in intervals:
+        mutated_sequence += sequence[mutated_sequence_offset:start]
+        chunk_length = end - start + 1
+        mutated_sequence += mutated_sequence_range[
+            mutated_range_offset : mutated_range_offset + chunk_length
+        ]
+        mutated_range_offset += chunk_length
+        mutated_sequence_offset = end + 1
+    mutated_sequence += sequence[end + 1 :]
+    return mutated_sequence
diff --git a/src/gt4sd/frameworks/enzeptional/tests/__init__.py b/src/gt4sd/frameworks/enzeptional/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/gt4sd/frameworks/enzeptional/tests/test_processing.py b/src/gt4sd/frameworks/enzeptional/tests/test_processing.py
new file mode 100644
index 000000000..419045b54
--- /dev/null
+++ b/src/gt4sd/frameworks/enzeptional/tests/test_processing.py
@@ -0,0 +1,22 @@
+"""Enzeptional processing tests."""
+
+from gt4sd.frameworks.enzeptional.processing import (
+    reconstruct_sequence_with_mutation_range,
+    sanitize_intervals,
+)
+
+
+def test_sanitize_intervals():
+    assert sanitize_intervals([(-5, 12), (13, 14), (2, 3), (-3, 4), (-2, 6)]) == [
+        (-5, 12),
+        (13, 14),
+    ]
+
+
+def test_reconstruct_sequence_with_mutation_range():
+    assert (
+        reconstruct_sequence_with_mutation_range(
+            "ABCDEFGHILMNOPQRSTUVWXYZ", "12789", [(0, 1), (6, 8)]
+        )
+        == "12CDEF789LMNOPQRSTUVWXYZ"
+    )
diff --git a/src/gt4sd/frameworks/granular/__init__.py b/src/gt4sd/frameworks/granular/__init__.py
new file mode 100644
index 000000000..dee2f3851
--- /dev/null
+++ b/src/gt4sd/frameworks/granular/__init__.py
@@ -0,0 +1 @@
+"""granular - GeneRative AutoeNcoders mULtimodAl Representations."""
diff --git a/src/gt4sd/frameworks/granular/arg_parser/__init__.py b/src/gt4sd/frameworks/granular/arg_parser/__init__.py
new file mode 100644
index 000000000..3e4483dc1
--- /dev/null
+++ b/src/gt4sd/frameworks/granular/arg_parser/__init__.py
@@ -0,0 +1 @@
+"""Arguments parser module."""
diff --git a/src/gt4sd/frameworks/granular/arg_parser/parser.py b/src/gt4sd/frameworks/granular/arg_parser/parser.py
new file mode 100644
index 000000000..16faec2a8
--- /dev/null
+++ b/src/gt4sd/frameworks/granular/arg_parser/parser.py
@@ -0,0 +1,96 @@
+import argparse
+import configparser
+from typing import Any, Dict, Optional
+
+from pytorch_lightning import Trainer
+
+from ..ml.models import ARCHITECTURE_FACTORY
+from .utils import convert_string_to_class
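+
+# Note: a hypothetical sketch of the expected configuration layout; the
+# section and option names below are illustrative, not a shipped default.
+# Sections other than [general]/[trainer]/[default] are treated as models.
+#
+#     [general]
+#     basename = my_run
+#
+#     [trainer]
+#     max_epochs = 10
+#
+#     [my_vae]
+#     type = vae_rnn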
+ """ + parser = argparse.ArgumentParser() + + # open config.ini file, either from parser or default file + parser.add_argument( + "--conf_file", + type=str, + help=("config file for the defaults value"), + default="./config/config.ini", + ) + + # Read config file + args, remaining_argv = parser.parse_known_args() + config = configparser.ConfigParser() + + if conf_file: + config.read(conf_file) + else: + config.read(args.conf_file) + + # classes that are not model name + general_config_classes = ["general", "trainer", "default"] + + # adding a list of all model name into the args + result: Dict[str, Any] = dict() + result["model_list"] = [ + i for i in list(config.keys()) if i.lower() not in general_config_classes + ] + for key in [*config.keys()]: + # go trough all models parameter, replace the parsed ones from the the config files ones + if key.lower() not in general_config_classes: + model_type = config[key]["type"] + params_from_configfile = dict(config[key]) + model = ARCHITECTURE_FACTORY[model_type.lower()] + parser = model.add_model_specific_args(parser, key) + args, _ = parser.parse_known_args() + args_dictionary = vars(args) + params_from_configfile["name"] = key + + for i in params_from_configfile: + params_from_configfile[i] = convert_string_to_class( + params_from_configfile[i] + ) + + params_from_configfile.update( + { + k[: -len(key) - 1]: v + for k, v in args_dictionary.items() + if v is not None and k.endswith("_" + key) + } + ) + + result[key] = params_from_configfile + + elif key.lower() == "trainer" or key.lower() == "general": + params_from_configfile = dict(config[key]) + for i in params_from_configfile: + params_from_configfile[i] = convert_string_to_class( + params_from_configfile[i] + ) + result.update(params_from_configfile) + + # parser Pytorch Trainer arguments + parser = Trainer.add_argparse_args(parser) + + # adding basename as the name of the run + parser.add_argument("--basename", type=str) + parser.add_argument("--batch_size", type=int) + parser.add_argument("--num_workers", type=int) + parser.add_argument("--lr", type=float) + parser.add_argument("--validation_split", type=float, default=None) + parser.add_argument("--validation_indices_file", type=str) + args_dictionary = vars(parser.parse_args(remaining_argv)) + result.update({k: v for k, v in args_dictionary.items() if v is not None}) + result_namespace = argparse.Namespace(**result) + + return result_namespace diff --git a/src/gt4sd/frameworks/granular/arg_parser/utils.py b/src/gt4sd/frameworks/granular/arg_parser/utils.py new file mode 100644 index 000000000..130df482f --- /dev/null +++ b/src/gt4sd/frameworks/granular/arg_parser/utils.py @@ -0,0 +1,51 @@ +"""Parsing utilties.""" + +import argparse +import ast +from typing import Union + + +def str2bool(s: Union[str, bool]) -> bool: + """Convert a string into a bool. + + Args: + s: a string representation of a boolean. + + Raises: + argparse.ArgumentTypeError: in case the conversion is failing. + + Returns: + the converted value. + """ + if isinstance(s, bool): + return s + if s.lower() in ("yes", "true", "t", "y", "1"): + return True + elif s.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + + +def convert_string_to_class(s: str): + """Convert a string into a python object. + + Fallback to ast in case of unexpected strings. + + Args: + s: a string. + + Returns: + the converted python object. 
+ """ + if s.lower() == "true": + return True + elif s.lower() == "false": + return False + elif s.lower() == "none": + return None + elif s: + try: + return ast.literal_eval(s) + except (ValueError, SyntaxError): + return s diff --git a/src/gt4sd/frameworks/granular/dataloader/__init__.py b/src/gt4sd/frameworks/granular/dataloader/__init__.py new file mode 100644 index 000000000..f179383d7 --- /dev/null +++ b/src/gt4sd/frameworks/granular/dataloader/__init__.py @@ -0,0 +1 @@ +"""Data loader module.""" diff --git a/src/gt4sd/frameworks/granular/dataloader/data_module.py b/src/gt4sd/frameworks/granular/dataloader/data_module.py new file mode 100644 index 000000000..d9620af2e --- /dev/null +++ b/src/gt4sd/frameworks/granular/dataloader/data_module.py @@ -0,0 +1,200 @@ +"""Data module for granular.""" + +import logging +from typing import Callable, List, Optional, cast + +import pandas as pd +import pytorch_lightning as pl +import torch +from torch.utils.data import DataLoader, Sampler, Subset, random_split + +from .dataset import CombinedGranularDataset, GranularDataset +from .sampler import StratifiedSampler + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class GranularDataModule(pl.LightningDataModule): + """Data module from granular.""" + + def __init__( + self, + dataset_list: List[GranularDataset], + validation_split: Optional[float] = None, + validation_indices_file: Optional[str] = None, + stratified_batch_file: Optional[str] = None, + stratified_value_name: Optional[str] = None, + batch_size: int = 64, + num_workers: int = 1, + ) -> None: + """Construct GranularDataModule. + + Args: + dataset_list: a list of granular datasets. + validation_split: proportion used for validation. Defaults to None, + a.k.a., use indices file if provided otherwise uses half of the data for validation. + validation_indices_file: indices to use for validation. Defaults to None, a.k.a., + use validation split proportion, if not provided uses half of the data for validation. + stratified_batch_file: stratified batch file for sampling. Defaults to None, a.k.a., + no stratified sampling. + stratified_value_name: stratified value name. Defaults to None, a.k.a., + no stratified sampling. Needed in case a stratified batch file is provided. + batch_size: batch size. Defaults to 64. + num_workers: number of workers. Defaults to 1. + """ + super().__init__() + self.batch_size = batch_size + self.validation_split = validation_split + self.validation_indices_file = validation_indices_file + self.dataset_list = dataset_list + self.num_workers = num_workers + self.stratified_batch_file = stratified_batch_file + self.stratified_value_name = stratified_value_name + self.prepare_train_data() + + @staticmethod + def combine_datasets( + dataset_list: List[GranularDataset], + ) -> CombinedGranularDataset: + """Combine granular datasets. + + Args: + dataset_list: a list of granular datasets. + + Returns: + a combined granular dataset. + """ + return CombinedGranularDataset( + [a_dataset.dataset for a_dataset in dataset_list] + ) + + def prepare_train_data(self) -> None: + """Prepare training dataset.""" + self.train_dataset = GranularDataModule.combine_datasets(self.dataset_list) + + def prepare_test_data(self, dataset_list: List[GranularDataset]) -> None: + """Prepare testing dataset. + + Args: + dataset_list: a list of granular datasets. 
+ """ + self.test_dataset = GranularDataModule.combine_datasets(dataset_list) + + def setup(self, stage: Optional[str] = None) -> None: + """Setup the data module. + + Args: + stage: stage considered, unused. Defaults to None. + """ + if ( + self.stratified_batch_file is not None + and self.stratified_value_name is None + ): + raise ValueError( + f"stratified_batch_file={self.stratified_batch_file}, need to set stratified_value_name as well" + ) + if self.validation_indices_file is None and self.validation_split is None: + self.validation_split = 0.5 + if self.validation_indices_file: + val_indices = ( + pd.read_csv(self.validation_indices_file).values.flatten().tolist() + ) + train_indices = [ + i for i in range(len(self.train_dataset)) if i not in val_indices + ] + self.train_data = Subset(self.train_dataset, train_indices) + self.val_data = Subset(self.train_dataset, val_indices) + + else: + val = int(len(self.train_dataset) * cast(float, (self.validation_split))) + train = len(self.train_dataset) - val + self.train_data, self.val_data = random_split( + self.train_dataset, [train, val] + ) + logger.info(f"number of data points used for training: {len(self.train_data)}") + logger.info(f"number of data points used for validation: {len(self.val_data)}") + logger.info( + f"validation proportion: {len(self.val_data) / (len(self.val_data) + len(self.train_data))}" + ) + + @staticmethod + def get_stratified_batch_sampler( + stratified_batch_file: str, + stratified_value_name: str, + batch_size: int, + selector_fn: Callable[[pd.DataFrame], pd.DataFrame], + ) -> StratifiedSampler: + """Get stratified batch sampler. + + Args: + stratified_batch_file: stratified batch file for sampling. + stratified_value_name: stratified value name. + batch_size: batch size. + selector_fn: selector function for stratified sampling. + Returns: + a stratified batch sampler. + """ + stratified_batch_dataframe = pd.read_csv(stratified_batch_file) + stratified_data = stratified_batch_dataframe[ + selector_fn(stratified_batch_dataframe) + ][stratified_value_name].values + stratified_data_tensor = torch.from_numpy(stratified_data) + return StratifiedSampler(targets=stratified_data_tensor, batch_size=batch_size) + + def train_dataloader(self) -> DataLoader: + """Get a training data loader. + + Returns: + a training data loader. + """ + sampler: Optional[Sampler] = None + if self.stratified_batch_file: + sampler = GranularDataModule.get_stratified_batch_sampler( + stratified_batch_file=self.stratified_batch_file, + stratified_value_name=str(self.stratified_value_name), + batch_size=self.batch_size, + selector_fn=lambda dataframe: ~dataframe["validation"], + ) + return DataLoader( + self.train_data, + num_workers=self.num_workers, + batch_size=self.batch_size, + pin_memory=False, + sampler=sampler, + ) + + def val_dataloader(self) -> DataLoader: + """Get a validation data loader. + + Returns: + a validation data loader. + """ + sampler: Optional[Sampler] = None + if self.stratified_batch_file: + sampler = GranularDataModule.get_stratified_batch_sampler( + stratified_batch_file=self.stratified_batch_file, + stratified_value_name=str(self.stratified_value_name), + batch_size=self.batch_size, + selector_fn=lambda dataframe: dataframe["validation"], + ) + return DataLoader( + self.val_data, + num_workers=self.num_workers, + batch_size=self.batch_size, + pin_memory=False, + sampler=sampler, + ) + + def test_dataloader(self) -> DataLoader: + """Get a testing data loader. + + Returns: + a testing data loader. 
+ """ + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=False, + ) diff --git a/src/gt4sd/frameworks/granular/dataloader/dataset.py b/src/gt4sd/frameworks/granular/dataloader/dataset.py new file mode 100644 index 000000000..e7fe024d2 --- /dev/null +++ b/src/gt4sd/frameworks/granular/dataloader/dataset.py @@ -0,0 +1,522 @@ +"""Dataset module.""" + +import logging +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Type + +import pandas as pd +import torch +from sklearn.compose import ColumnTransformer +from sklearn.preprocessing import MinMaxScaler, OneHotEncoder, StandardScaler +from torch.utils.data import Dataset + +from ..ml.models import ARCHITECTURE_FACTORY, AUTOENCODER_ARCHITECTURES +from ..tokenizer.tokenizer import TOKENIZER_FACTORY, Tokenizer + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + +SCALING_FACTORY_FN: Dict[str, Callable] = { + "onehot": lambda: OneHotEncoder(handle_unknown="error", sparse=False), + "min-max": lambda: MinMaxScaler(), + "standard": lambda: StandardScaler(), +} +MODEL_TYPES = set(ARCHITECTURE_FACTORY.keys()) + + +class GranularDataset(Dataset): + """A dataset wrapper for granular""" + + def __init__(self, name: str, data: Dict[str, Any]) -> None: + """Initialize a granular dataset. + + Args: + name: dataset name. + data: dataset samples. + """ + self.dataset: Dict[str, Any] = {"name": name, "data": data} + + def __len__(self) -> int: + """Dataset length. + + Returns: + length of the dataset. + """ + lengths = {key: len(data) for key, data in self.dataset["data"].items()} + if len(set(lengths.values())) > 1: + raise ValueError(f"mismatching dimensions for the data: {lengths}") + return list(lengths.values())[0] + + def __getitem__(self, index: int) -> Dict[str, Any]: + """Retrieve an item from the dataset by index. + + Args: + index: index for the item. + + Returns: + an item. + """ + result = dict() + for key in self.dataset["data"]: + result[self.dataset["name"] + "_" + key] = self.dataset["data"][key][index] + return result + + +class CombinedGranularDataset(Dataset): + """General dataset combining multiple granular datasets.""" + + def __init__(self, datasets: List[Dict[str, Any]]) -> None: + """Initialize a general dataset. + + Args: + datasets: list of dataset configurations. + """ + self.datasets = datasets + self.names = [data["name"] for data in datasets] + + def __len__(self) -> int: + """Dataset length. + + Returns: + length of the dataset. + """ + return len([*self.datasets[0]["data"].values()][0]) + + def __getitem__(self, index: int) -> Dict[str, Any]: + """Retrieve an item from the dataset by index. + + Args: + index: index for the item. + + Returns: + an item. + """ + result = dict() + for dataset in self.datasets: + keys = [*dataset["data"]] + for key in keys: + result[dataset["name"] + "_" + key] = dataset["data"][key][index] + return result + + +class SmilesTokenizationPreProcessingDataset(GranularDataset): + """Dataset for SMILES/SELFIES preprocessing.""" + + def __init__( + self, + name: str, + data_columns: Dict[str, Any], + input_smiles: pd.DataFrame, + target_smiles: pd.DataFrame, + tokenizer: Tokenizer, + set_seq_size: Optional[int] = None, + ) -> None: + """Construct a SmilesTokenizationPreProcessingDataset. + + Args: + name: dataset name. + data_columns: data columns mapping. + input_smiles: dataframe containing input SMILES. + target_smiles: dataframe containing target SMILES. 
+ tokenizer: a tokenizer defining the molecule representation used. + set_seq_size: sequence size. Defaults to None, a.k.a., define this + using the input SMILES. + """ + self.name = name + self.input_smiles = input_smiles.values.flatten().tolist() + self.target_smiles = target_smiles.values.flatten().tolist() + self.tokenizer = tokenizer + self.input_tokens: List[torch.Tensor] = [] + self.target_tokens: List[torch.Tensor] = [] + + tokens_ids = [ + tokenizer.convert_tokens_to_ids(tokenizer.tokenize(smile)) + for smile in self.input_smiles + ] + if set_seq_size: + self.set_seq_size = set_seq_size + else: + self.set_seq_size = max([len(i) for i in tokens_ids]) + 20 + + self.smiles_to_ids(input_smiles=self.input_smiles) + self.smiles_to_ids(target_smiles=self.target_smiles) + + super().__init__( + name=name, + data={ + data_columns["input"]: self.input_tokens, + data_columns["target"]: self.target_tokens, + }, + ) + + def smiles_to_ids( + self, input_smiles: List[str] = [], target_smiles: List[str] = [] + ) -> None: + """Process input SMILES lists generating examples by tokenizing strings and converting them to tensors. + + Args: + input_smiles: list of input SMILES representations. Defaults to []. + target_smiles: list of target SMILES representations. Defaults to []. + """ + if len(input_smiles) > 0 and len(target_smiles) == 0: + self.input_smiles = input_smiles + smiles = input_smiles + elif len(input_smiles) == 0 and len(target_smiles) > 0: + self.target_smiles = target_smiles + smiles = target_smiles + else: + raise Exception( + "Either input_smiles or target_smiles needs to be specified" + ) + + tokens_ids = [ + self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(smile)) + for smile in smiles + ] + examples = [] + for token in tokens_ids: + example_tokens = self.tokenizer.convert_tokens_to_ids( + [self.tokenizer.sos_token] + ) + example_tokens.extend(token) + example_tokens.extend( + self.tokenizer.convert_tokens_to_ids([self.tokenizer.eos_token]) + ) + examples.append( + torch.tensor( + self.tokenizer.add_padding_tokens(example_tokens, self.set_seq_size) + ) + ) + + if len(input_smiles) > 0 and len(target_smiles) == 0: + self.input_tokens = examples + elif len(input_smiles) == 0 and len(target_smiles) > 0: + self.target_tokens = examples + + +class LatentModelDataset(GranularDataset): + """Latent model dataset.""" + + def __init__( + self, + name: str, + data_columns: Dict[str, Any], + target_data: pd.DataFrame, + scaling: Optional[str] = None, + ) -> None: + """Construct a LatentModelDataset. + + Args: + name: dataset name. + data_columns: data columns mapping. + target_data: dataframe for targets. + scaling: feature scaling process. Defaults to None, a.k.a. no scaling. Currently not supported. + + Raises: + NotImplementedError: in case a scaling is selected. + """ + self.name = name + if scaling: + raise NotImplementedError("Scaling not yet supported") + self.target_data = torch.from_numpy(target_data.values) + self.target_data = self.target_data.type(torch.float) + self.target_size = target_data.shape[1] + super().__init__(name=name, data={data_columns["target"]: self.target_data}) + + +class AutoEncoderDataset(GranularDataset): + """Autoencoder dataset.""" + + def __init__( + self, + name: str, + data_columns: Dict[str, Any], + input_data: pd.DataFrame, + target_data: pd.DataFrame, + scaling: Optional[str] = None, + ) -> None: + """Construct an AutoEncoderDataset. + + Args: + name: dataset name. + data_columns: data columns mapping. + input_data: dataframe for inputs. 
+ target_data: dataframe for targets. + scaling: feature scaling process. Defaults to None, a.k.a. no scaling. Feasible values: "onehot", "min-max" and "standard". + + Raises: + ValueError: in case requested scaling is not supported. + """ + self.name = name + self.data_columns = data_columns + + if scaling is None: + self.input_data = torch.from_numpy(input_data.values) + self.target_data = torch.from_numpy(target_data.values) + else: + if scaling not in SCALING_FACTORY_FN: + raise ValueError( + f"Scaling={scaling} not supported. Pick a valid one: {sorted(list(SCALING_FACTORY_FN.keys()))}" + ) + + self.input_scaling = ColumnTransformer( + transformers=[ + ( + "InputScaling", + SCALING_FACTORY_FN[scaling](), + [data_columns["input"]], + ) + ] + ) + self.target_scaling = ColumnTransformer( + transformers=[ + ( + "TargetScaling", + SCALING_FACTORY_FN[scaling](), + [data_columns["target"]], + ) + ] + ) + + self.input_data = torch.from_numpy( + self.input_scaling.fit_transform(pd.concat([input_data], axis=1)) + ) + self.target_data = torch.from_numpy( + self.target_scaling.fit_transform(pd.concat([target_data], axis=1)) + ) + + self.input_data, self.target_data = ( + self.input_data.type(torch.float), + self.target_data.type(torch.float), + ) + self.input_size = self.input_data.shape[1] + self.target_size = self.target_data.shape[1] + + super().__init__( + name=name, + data={ + data_columns["input"]: self.input_data, + data_columns["target"]: self.target_data, + }, + ) + + +DATASET_FACTORY: Dict[str, Type[GranularDataset]] = { + "latentmodel": LatentModelDataset, + "smiles": SmilesTokenizationPreProcessingDataset, + "selfies": SmilesTokenizationPreProcessingDataset, + "autoencoder": AutoEncoderDataset, +} + + +def build_data_columns(hparams: Dict[str, Any]) -> Dict[str, Any]: + """Build data columns from hyper-parameters. + + Args: + hparams: hyper-parameters for the data columns. + + Returns: + data columns. + """ + try: + input_columns = hparams["input"] + except KeyError: + input_columns = None + try: + target_columns = hparams["target"] + except KeyError: + target_columns = None + # create dictionary + if input_columns: + data_columns = {"input": input_columns, "target": target_columns} + else: + data_columns = {"target": target_columns} + return data_columns + + +def build_dataset( + name: str, + data: pd.DataFrame, + dataset_type: str, + data_columns: Dict[str, Any], + hparams: Dict[str, Any], +) -> GranularDataset: + """Build a granular dataset. + + Args: + name: dataset name. + data: dataframe representing the dataset. + dataset_type: dataset type. Feasible values: "latentmodel", "smiles", "selfies" and "autoencoder". + data_columns: data columns mapping. + hparams: hyper-parameters for the data columns. + + Raises: + ValueError: in case requested dataset type is not supported. + + Returns: + a granular dataset. + """ + dataset: GranularDataset + dataset_type = dataset_type.lower() + if dataset_type not in DATASET_FACTORY: + raise ValueError( + f"dataset_type={dataset_type} not supported. 
Pick a valid one: {sorted(list(DATASET_FACTORY.keys()))}" + ) + + input_columns: List[Any] + if not dataset_type == "latentmodel": + if data_columns["input"] == "all": + input_columns = data.columns.tolist() + else: + if isinstance(data_columns["input"], list): + input_columns = data_columns["input"] + else: + input_columns = [data_columns["input"]] + + target_columns: List[Any] + if data_columns["target"] == "all": + target_columns = data.columns.tolist() + else: + if isinstance(data_columns["target"], list): + target_columns = data_columns["target"] + else: + target_columns = [data_columns["target"]] + + if dataset_type in {"smiles", "selfies"}: + try: + build_vocab = hparams["build_vocab"] + except KeyError: + build_vocab = None + try: + sequence_size = hparams["sequence_size"] + except KeyError: + sequence_size = None + vocab_file = hparams["vocab_file"] + + # build tokenizer + if build_vocab: + tokenizer = TOKENIZER_FACTORY[dataset_type]( + vocab_file, smiles=data[input_columns].squeeze().tolist() + ) + else: + tokenizer = TOKENIZER_FACTORY[dataset_type](vocab_file, smiles=[]) + dataset = SmilesTokenizationPreProcessingDataset( + name=name, + data_columns=data_columns, + input_smiles=data[input_columns], + target_smiles=data[target_columns], + tokenizer=tokenizer, + set_seq_size=sequence_size, + ) + elif dataset_type == "latentmodel": + dataset = LatentModelDataset( + name=name, + data_columns=data_columns, + target_data=data[target_columns], + scaling=None, + ) + elif dataset_type == "autoencoder": + dataset = AutoEncoderDataset( + name=name, + data_columns=data_columns, + input_data=data[input_columns], + target_data=data[target_columns], + scaling=hparams["scaling"], + ) + + return dataset + + +def build_architecture( + model_type: str, + data_columns: Dict[str, Any], + dataset: GranularDataset, + hparams: Dict[str, Any], +) -> Dict[str, Any]: + """Build architecture configuration for the selected model type and dataset. + + Args: + model_type: model type. Feasible values: "vae_rnn", "vae_trans", "mlp_predictor", "no_encoding", "mlp_autoencoder" and "vae_mlp". + data_columns: data columns mapping. + dataset: a granular dataset. + hparams: hyper-parameters for the data columns. + + Raises: + ValueError: in case requested model type is not supported. + + Returns: + architecture configuration. + """ + model_type = model_type.lower() + if model_type not in MODEL_TYPES: + raise ValueError( + f"model_type={model_type} not supported. 
Pick a valid one: {sorted(list(MODEL_TYPES))}" + ) + + architecture: Dict[str, Any] = { + "name": hparams["name"], + "type": hparams["type"], + "start_from_checkpoint": hparams["start_from_checkpoint"], + "freeze_weights": hparams["freeze_weights"], + "data": data_columns, + "hparams": hparams, + } + + if model_type in AUTOENCODER_ARCHITECTURES: + architecture["position"] = hparams["position"] + if model_type in {"vae_rnn", "vae_trans"}: + hparams["tokenizer"] = dataset.tokenizer + hparams["vocab_size"] = dataset.tokenizer.vocab_size + if model_type == "vae_rnn": + hparams["embedding_size"] = dataset.set_seq_size + else: # "vae_trans" + hparams["sequence_len"] = dataset.set_seq_size + elif model_type == "no_encoding": + hparams["latent_size"] = dataset.input_size + elif model_type in {"mlp_autoencoder", "vae_mlp"}: + hparams["input_size_enc"] = dataset.input_size + hparams["output_size_dec"] = dataset.target_size + else: # "mlp_predictor" + hparams["output_size"] = dataset.target_size + architecture["from_position"] = hparams["from_position"] + + return architecture + + +def build_dataset_and_architecture( + name: str, + data_path: str, + data_file: str, + dataset_type: str, + model_type: str, + hparams: Dict[str, Any], + **kwargs, +) -> Tuple[GranularDataset, Dict[str, Any]]: + """Build a dataset and an architecture configuration. + + Args: + name: dataset name. + data_path: path to the dataset. + data_file: data file name. + dataset_type: dataset type. Feasible values: "latentmodel", "smiles", "selfies" and "autoencoder". + model_type: model type. Feasible values: "vae_rnn", "vae_trans", "mlp_predictor", "no_encoding", "mlp_autoencoder" and "vae_mlp". + hparams: hyper-parameters for the data columns. + + Raises: + ValueError: in case the data file has an unsupported extension/format. + + Returns: + a tuple containing a granular dataset and a related architecture configuration. + """ + if data_file.endswith(".csv"): + data = pd.read_csv(f"{data_path}{os.path.sep}{data_file}") + elif data_file.endswith(".bz2") or data_file.endswith(".pkl"): + data = pd.read_pickle(f"{data_path}{os.path.sep}{data_file}") + else: + raise ValueError( + f"data_file={data_file} extension not supported. Use a compatible extension/format: {['.csv', '.bz2', '.pkl']}" + ) + data_columns = build_data_columns(hparams) + dataset = build_dataset(name, data, dataset_type, data_columns, hparams) + architecture = build_architecture(model_type, data_columns, dataset, hparams) + return dataset, architecture diff --git a/src/gt4sd/frameworks/granular/dataloader/sampler.py b/src/gt4sd/frameworks/granular/dataloader/sampler.py new file mode 100644 index 000000000..5620a4023 --- /dev/null +++ b/src/gt4sd/frameworks/granular/dataloader/sampler.py @@ -0,0 +1,69 @@ +""" +Sampler implementation. + +Reimplemented starting from: https://github.com/ncullen93/torchsample/blob/ea4d1b3975f68be0521941e733887ed667a1b46e/torchsample/samplers.py. +The main reason for reimplementation is to avoid to add a dependency and to control better the logger. 
+""" + +import logging +from typing import Iterator + +import numpy as np +import torch +from sklearn.model_selection import StratifiedShuffleSplit +from torch.utils.data import Sampler + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class StratifiedSampler(Sampler): + """Implementation of a sampler for tensors based on scikit-learn StratifiedShuffleSplit.""" + + def __init__( + self, targets: torch.Tensor, batch_size: int, test_size: float = 0.5 + ) -> None: + """Construct a StratifiedSampler. + + Args: + targets: targets tensor. + batch_size: size of the batch. + test_size: proportion of samples in the test set. Defaults to 0.5. + """ + self.targets = targets + self.number_of_splits = int(self.targets.size(0) / batch_size) + self.test_size = test_size + + def gen_sample_array(self) -> np.ndarray: + """Get sample array. + + Returns: + sample array. + """ + splitter = StratifiedShuffleSplit( + n_splits=self.number_of_splits, test_size=self.test_size + ) + data_placeholder = torch.randn(self.targets.size(0), 2).numpy() + targets = self.targets.numpy() + splitter.get_n_splits(data_placeholder, targets) + train_index, test_index = next(splitter.split(data_placeholder, targets)) + return np.hstack([train_index, test_index]) + + def __iter__(self) -> Iterator[np.ndarray]: + """Get an iterator over the sample array. + + Returns: + sample array iterator. + + Yields: + a sample array. + """ + return iter(self.gen_sample_array()) + + def __len__(self) -> int: + """Length of the sampler. + + Returns: + the sampler length. + """ + return len(self.targets) diff --git a/src/gt4sd/frameworks/granular/ml/__init__.py b/src/gt4sd/frameworks/granular/ml/__init__.py new file mode 100644 index 000000000..cb676291c --- /dev/null +++ b/src/gt4sd/frameworks/granular/ml/__init__.py @@ -0,0 +1 @@ +"""ML module.""" diff --git a/src/gt4sd/frameworks/granular/ml/models/__init__.py b/src/gt4sd/frameworks/granular/ml/models/__init__.py new file mode 100644 index 000000000..8e16423a0 --- /dev/null +++ b/src/gt4sd/frameworks/granular/ml/models/__init__.py @@ -0,0 +1,29 @@ +"""Model initialization module.""" + +from typing import Dict, Type + +from .activation import ACTIVATION_FACTORY # noqa: F401 +from .base_model import GranularBaseModel, GranularEncoderDecoderModel +from .loss import LOSS_FACTORY # noqa: F401 +from .mlp_auto_encoder import MlpAutoEncoder +from .mlp_predictor import MlpPredictor +from .no_encoding import NoEncoding +from .vae_mlp import VaeMlp +from .vae_rnn import VaeRnn +from .vae_trans import VaeTrans + +ARCHITECTURE_FACTORY: Dict[str, Type[GranularBaseModel]] = { + "vae_rnn": VaeRnn, + "vae_trans": VaeTrans, + "mlp_predictor": MlpPredictor, + "no_encoding": NoEncoding, + "mlp_autoencoder": MlpAutoEncoder, + "vae_mlp": VaeMlp, +} +AUTOENCODER_ARCHITECTURES = set( + [ + model_type + for model_type, model_class in ARCHITECTURE_FACTORY.items() + if issubclass(model_class, GranularEncoderDecoderModel) + ] +) diff --git a/src/gt4sd/frameworks/granular/ml/models/activation.py b/src/gt4sd/frameworks/granular/ml/models/activation.py new file mode 100644 index 000000000..9edfa9066 --- /dev/null +++ b/src/gt4sd/frameworks/granular/ml/models/activation.py @@ -0,0 +1,12 @@ +"""Activations for granular models.""" + +from typing import Dict + +from torch import nn + +ACTIVATION_FACTORY: Dict[str, nn.Module] = { + "sigmoid": nn.Sigmoid(), + "tanh": nn.Tanh(), + "softmax": nn.Softmax(), + "relu": nn.ReLU(), +} diff --git a/src/gt4sd/frameworks/granular/ml/models/base_model.py 
b/src/gt4sd/frameworks/granular/ml/models/base_model.py new file mode 100644 index 000000000..4d20abfa4 --- /dev/null +++ b/src/gt4sd/frameworks/granular/ml/models/base_model.py @@ -0,0 +1,182 @@ +"""Base model for granular.""" + +from abc import abstractmethod +from argparse import ArgumentParser +from typing import Any, Dict, List, Tuple + +import torch +from torch import nn +from torch.distributions import Distribution + + +class GranularBaseModel(nn.Module): + """Base model class.""" + + position: int + from_position: List[int] + + def __init__(self, name: str, data: Dict[str, str], *args, **kwargs) -> None: + """Construct GranularBaseModel. + + Args: + name: model name. + data: data name mappings. + """ + super().__init__() + self.name = name + self.data = data + + def forward(self, x: Any, *args, **kwargs) -> Any: + """Forward pass in the model. + + Args: + x: model input. + + Returns: + model output. + """ + return self._run_step(x) + + @abstractmethod + def _run_step(self, x: Any, *args, **kwargs) -> Any: + """Run a step in the model. + + Args: + x: model input. + + Returns: + model step output. + """ + pass + + @abstractmethod + def step( + self, + input_data: Any, + target_data: Any, + device: str = "cpu", + current_epoch: int = 0, + *args, + **kwargs, + ) -> Tuple[Any, Any, Any]: + """Training step for the model. + + Args: + input_data: input for the step. + target_data: target for the step. + device: string representing the device to use. Defaults to "cpu". + current_epoch: current epoch. Defaults to 0. + + Returns: + a tuple containing the step output, the loss and the logs for the module. + """ + pass + + @abstractmethod + def val_step( + self, + input_data: Any, + target_data: Any, + device: str = "cpu", + current_epoch: int = 0, + *args, + **kwargs, + ) -> Any: + """Validation step for the model. + + Args: + input_data: input for the step. + target_data: target for the step. + device: string representing the device to use. Defaults to "cpu". + current_epoch: current epoch. Defaults to 0. + + Returns: + a tuple containing the step output, the loss and the logs for the module. + """ + pass + + @staticmethod + def add_model_specific_args( + parent_parser: ArgumentParser, name: str, *args, **kwargs + ) -> ArgumentParser: + """Adding to a parser model specific arguments. + + Args: + parent_parser: patent parser. + name: model name. + + Returns: + updated parser. + """ + return parent_parser + + +class GranularEncoderDecoderModel(GranularBaseModel): + """Autoencoder model class.""" + + latent_size: int + + @abstractmethod + def decode(self, z: Any, *args, **kwargs) -> Any: + """Decode a latent space point. + + Args: + z: latent point. + + Returns: + decoded sample. + """ + pass + + @abstractmethod + def encode(self, x: Any, *args, **kwargs) -> Any: + """Encode a sample. + + Args: + x: input sample. + + Returns: + latent encoding. + """ + pass + + def encode_decode(self, x: Any, *args, **kwargs) -> Any: + """Encode and decode a sample. + + Args: + x: input sample. + + Returns: + decoded sample. + """ + z = self.encode(x) + return self.decode(z) + + def inference(self, z: Any, *args, **kwargs) -> Any: + """Run the model in inference mode. + + Args: + z: sample. + + Returns: + generated output. + """ + return self.decode(z) + + def sample( + self, mu: torch.Tensor, log_var: torch.Tensor + ) -> Tuple[Distribution, Distribution, torch.Tensor]: + """Sample a point from a given mean and average following a normal log-likelihood. + + Args: + mu: mean tensor. + log_var: log varian tensor. 
+ + Returns: + a tuple containing standard normal, localized normal and the sampled point. + """ + std = torch.exp(log_var / 2.0) + p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std)) + q = torch.distributions.Normal(mu, std) + z = q.rsample() + return p, q, z diff --git a/src/gt4sd/frameworks/granular/ml/models/loss.py b/src/gt4sd/frameworks/granular/ml/models/loss.py new file mode 100644 index 000000000..49996d4a1 --- /dev/null +++ b/src/gt4sd/frameworks/granular/ml/models/loss.py @@ -0,0 +1,204 @@ +"""Losses for granular models.""" + +from typing import Any, Dict + +import torch +from torch import nn + + +class MSLELossNegMix9(nn.Module): + """MSLE loss negative mix 9.""" + + def __init__(self) -> None: + """Initialize the loss.""" + super().__init__() + self.mse = nn.MSELoss(reduction="sum") + + def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> Any: + """Forward pass in the loss. + + Args: + prediction: predictions. + target: groundtruth. + + Returns: + loss value. + """ + pred2 = prediction.clone() + true2 = target.clone() + pred2[pred2 < 0] = 0 + pred2 = pred2 + 1e-6 + true2 = true2 + 1e-6 + pred3 = prediction.clone() + true3 = target.clone() + pred3[target < 0.0001] = 0 + true3[target < 0.0001] = 0 + pred4 = prediction.clone() + true4 = target.clone() + pred4[pred4 > 0] = 0 + true4[true4 < 2] = 0 + l4_ = self.mse(pred4, true4) + l1_ = self.mse(pred3 / (0.001 + true3), true3 / (0.001 + true3)) + l2_ = self.mse(prediction, target) + l3_ = self.mse(torch.log(pred2), torch.log(true2)) + l0_ = torch.abs(l1_ * 0.1 + l2_ * 1.0 + l3_ * 1.0e-5 + l4_ * 10.0) + return l0_ + + +class MSLELossNegMix91(nn.Module): + """MSLE loss negative mix 91.""" + + def __init__(self): + """Initialize the loss.""" + super().__init__() + self.mse = nn.MSELoss(reduction="sum") + self.mae = nn.L1Loss(reduction="sum") + + def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> Any: + """Forward pass in the loss. + + Args: + prediction: predictions. + target: groundtruth. + + Returns: + loss value. + """ + pred2 = prediction.clone() + true2 = target.clone() + pred2[pred2 < 0] = 0 + pred2 = pred2 + 1e-6 + true2 = true2 + 1e-6 + l1_ = self.mae(prediction, target) + l2_ = self.mse(prediction, target) + l3_ = self.mse(torch.log(pred2), torch.log(true2)) + l0_ = torch.abs(l1_ * 0.3 + l2_ * 1.0 + l3_ * 1.0e-5) + return l0_ + + +class MseWithNans(nn.Module): + """MSE with NaNs handling.""" + + def __init__(self): + """Initialize the loss.""" + super().__init__() + + def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> Any: + """Forward pass in the loss. + + Args: + prediction: predictions. + target: groundtruth. + + Returns: + loss value. + """ + mask = torch.isnan(target) + out = (prediction[~mask] - target[~mask]) ** 2 + loss = out.mean() + return loss + + +class MaeWithNans(nn.Module): + """MAE with NaNs handling.""" + + def __init__(self): + """Initialize the loss.""" + super().__init__() + + def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> Any: + """Forward pass in the loss. + + Args: + prediction: predictions. + target: groundtruth. + + Returns: + loss value. 
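The NaN-aware losses in this file mask out entries whose target is NaN before reducing, so missing labels do not poison the gradient. A minimal standalone sketch of that masking pattern (toy tensors, not the modules themselves):

```python
import torch

# Masked-MSE pattern used by MseWithNans (and, with abs, by MaeWithNans):
# entries whose target is NaN are dropped before the reduction.
prediction = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.5, float("nan"), 2.0])
mask = torch.isnan(target)
loss = ((prediction[~mask] - target[~mask]) ** 2).mean()
print(loss)  # (0.5**2 + 1.0**2) / 2 = 0.625
```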
+ """ + mask = torch.isnan(target) + if sum(mask) == len(target): + return torch.tensor(0).type_as(prediction) + out = abs((prediction[~mask] - target[~mask])) + loss = sum(out) / len(prediction[~mask]) + return loss + + +class MSLELossNegMix92(nn.Module): + """MSLE loss negative mix 92.""" + + def __init__(self): + """Initialize the loss.""" + super().__init__() + self.mae = nn.L1Loss(reduction="mean") + + def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> Any: + """Forward pass in the loss. + + Args: + prediction: predictions. + target: groundtruth. + + Returns: + loss value. + """ + l1_ = self.mae(prediction, target) + mask = target.ge(0.001) + l2_ = self.mae(prediction[mask], target[mask]) + l0_ = torch.abs(l1_ + l2_ * 10.0) + return l0_ + + +class MSLELossNegMix99(nn.Module): + """MSLE loss negative mix 99.""" + + def __init__(self): + """Initialize the loss.""" + super().__init__() + self.mse = nn.MSELoss(reduction="sum") + + def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> Any: + """Forward pass in the loss. + + Args: + prediction: predictions. + target: groundtruth. + + Returns: + loss value. + """ + mask = torch.isnan(target) + out = (prediction[~mask] - target[~mask]) ** 2 + loss = out.mean() + return loss + + +LOSS_FACTORY: Dict[str, nn.Module] = { + "mse": nn.MSELoss(), + "mse-sum": nn.MSELoss(reduction="sum"), + "mse-mean": nn.MSELoss(reduction="mean"), + "mse-with-nans": MseWithNans(), + "msewithnans": MseWithNans(), + "mae": nn.L1Loss(), + "mae-sum": nn.L1Loss(reduction="sum"), + "mae-mean": nn.L1Loss(reduction="mean"), + "mae-with-nans": MaeWithNans(), + "maewithnans": MaeWithNans(), + "bce": nn.BCELoss(), + "bce-with-logits": nn.BCEWithLogitsLoss(), + "bcewl": nn.BCEWithLogitsLoss(), + "loss9": MSLELossNegMix9(), + "l9": MSLELossNegMix9(), + "msle-neg-mix-9": MSLELossNegMix9(), + "loss91": MSLELossNegMix91(), + "l91": MSLELossNegMix91(), + "msle-neg-mix-91": MSLELossNegMix91(), + "loss92": MSLELossNegMix92(), + "l92": MSLELossNegMix92(), + "msle-neg-mix-92": MSLELossNegMix92(), + "loss99": MSLELossNegMix99(), + "l99": MSLELossNegMix99(), + "msle-neg-mix-99": MSLELossNegMix99(), + "crossentropyloss": nn.CrossEntropyLoss(), + "ce": nn.CrossEntropyLoss(), +} diff --git a/src/gt4sd/frameworks/granular/ml/models/mlp_auto_encoder/__init__.py b/src/gt4sd/frameworks/granular/ml/models/mlp_auto_encoder/__init__.py new file mode 100644 index 000000000..bde8a3f5a --- /dev/null +++ b/src/gt4sd/frameworks/granular/ml/models/mlp_auto_encoder/__init__.py @@ -0,0 +1,3 @@ +"""Initialize MLP autoencoder module.""" + +from .core import MlpAutoEncoder # noqa: F401 diff --git a/src/gt4sd/frameworks/granular/ml/models/mlp_auto_encoder/core.py b/src/gt4sd/frameworks/granular/ml/models/mlp_auto_encoder/core.py new file mode 100644 index 000000000..69c6ab2b6 --- /dev/null +++ b/src/gt4sd/frameworks/granular/ml/models/mlp_auto_encoder/core.py @@ -0,0 +1,223 @@ +"""MLP autoencoder implementation.""" + +from argparse import ArgumentParser +from typing import Any, Dict, Tuple + +from ....arg_parser.utils import str2bool +from ..base_model import GranularEncoderDecoderModel +from ..loss import LOSS_FACTORY +from ..module import MlpDecoder, MlpEncoder + + +class MlpAutoEncoder(GranularEncoderDecoderModel): + """MlpAutoencoder - Multi Layer Perceptron autoencoder.""" + + def __init__( + self, + name: str, + position: int, + data: Dict[str, str], + input_size_enc: int = 256, + hidden_size_enc: int = 256, + n_layers_enc: int = 2, + activation_enc: str = "linear", + 
dropout_enc: float = 0.0, + hidden_size_dec: int = 256, + n_layers_dec: int = 2, + activation_dec: str = "linear", + dropout_dec: float = 0.0, + output_size_dec: int = 256, + latent_size: int = 196, + loss_function: str = "mse", + **kwargs, + ) -> None: + """Construct MlpAutoEncoder. + + Args: + name: model name. + position: position of the model. + data: data name mappings. + input_size_enc: encoder input size. Defaults to 256. + hidden_size_enc: encoder hidden size. Defaults to 256. + n_layers_enc: number of layers for the encoder. Defaults to 2. + activation_enc: activation function for the encoder. Defaults to "linear". + dropout_enc: encoder dropout rate. Defaults to 0.0. + hidden_size_dec: decoder hidden size. Defaults to 256. + n_layers_dec: number of layers for the decoder. Defaults to 2. + activation_dec: activation function for the decoder. Defaults to "linear". + dropout_dec: decoder dropout rate. Defaults to 0.0. + output_size_dec: decoder output size. Defaults to 256. + latent_size: size of the latent space. Defaults to 196. + loss_function: loss function. Defaults to "mse". + + Raises: + ValueError: in case the provided loss function is not supported. + """ + super().__init__(name=name, data=data) + self.position = position + self.input_key = f"{name}_{data['input']}" + self.target_key = f"{name}_{data['target']}" + + self.latent_size = latent_size + self.input_size_enc = input_size_enc + self.hidden_size_enc = hidden_size_enc + self.n_layers_enc = n_layers_enc + self.activation_enc = activation_enc + self.dropout_enc = dropout_enc + self.output_size_enc = latent_size + + self.hidden_size_dec = hidden_size_dec + self.n_layers_dec = n_layers_dec + self.activation_dec = activation_dec + self.dropout_dec = dropout_dec + self.output_size_dec = output_size_dec + + self.loss_function_name = loss_function.lower() + if self.loss_function_name not in LOSS_FACTORY: + raise ValueError( + f"loss_function={self.loss_function_name} not supported. Pick a valid one: {sorted(list(LOSS_FACTORY.keys()))}" + ) + self.loss_function = LOSS_FACTORY[self.loss_function_name] + + self.encoder = MlpEncoder( + input_size=input_size_enc, + hidden_size=hidden_size_enc, + output_size=latent_size, + n_layers=n_layers_enc, + activation=activation_enc, + dropout=dropout_enc, + ) + self.decoder = MlpDecoder( + latent_size=latent_size, + hidden_size=hidden_size_dec, + output_size=output_size_dec, + n_layers=n_layers_dec, + activation=activation_dec, + dropout=dropout_dec, + ) + self.epoch_counter = 0 + + def decode(self, z: Any, *args, **kwargs) -> Any: + """Decode a latent space point. + + Args: + z: latent point. + + Returns: + decoded sample. + """ + return self.decoder(z) + + def encode(self, x: Any, *args, **kwargs) -> Any: + """Encode a sample. + + Args: + x: input sample. + + Returns: + latent encoding. + """ + return self.encoder(x) + + def _run_step(self, x: Any, *args, **kwargs) -> Any: + """Run a step in the model. + + Args: + x: model input. + + Returns: + model step output. + """ + z = self.encoder(x) + x_out = self.decoder(z) + return z, x_out + + def step( + self, + input_data: Any, + target_data: Any, + device: str = "cpu", + current_epoch: int = 0, + *args, + **kwargs, + ) -> Tuple[Any, Any, Any]: + """Training step for the model. + + Args: + input_data: input for the step. + target_data: target for the step. + device: string representing the device to use. Defaults to "cpu". + current_epoch: current epoch. Defaults to 0. 
+
+        Returns:
+            a tuple containing the step output, the loss and the logs for the module.
+        """
+        x = input_data
+        x_target = target_data
+
+        z, x_hat = self._run_step(x)
+        loss = self.loss_function(x_hat, x_target)
+        logs = {"reconstruction_loss": loss, "loss": loss}
+        return z, loss, logs
+
+    def val_step(
+        self,
+        input_data: Any,
+        target_data: Any,
+        device: str = "cpu",
+        current_epoch: int = 0,
+        *args,
+        **kwargs,
+    ) -> Any:
+        """Validation step for the model.
+
+        Args:
+            input_data: input for the step.
+            target_data: target for the step.
+            device: string representing the device to use. Defaults to "cpu".
+            current_epoch: current epoch. Defaults to 0.
+
+        Returns:
+            a tuple containing the step output, the loss and the logs for the module.
+        """
+        return self.step(input_data, target_data, device, current_epoch)
+
+    @staticmethod
+    def add_model_specific_args(
+        parent_parser: ArgumentParser, name: str, *args, **kwargs
+    ) -> ArgumentParser:
+        """Add model specific arguments to a parser.
+
+        Args:
+            parent_parser: parent parser.
+            name: model name.
+
+        Returns:
+            updated parser.
+        """
+        parser = ArgumentParser(parents=[parent_parser], add_help=False)
+
+        parser.add_argument(f"--data_path_{name}", type=str)
+        parser.add_argument(f"--data_file_{name}", type=str)
+        parser.add_argument(f"--dataset_type_{name}", type=str)
+        parser.add_argument(f"--position_{name}", type=int, nargs="+")
+        parser.add_argument(f"--input_{name}", type=str)
+        parser.add_argument(f"--target_{name}", type=str)
+        parser.add_argument(f"--checkpoint_path_{name}", type=str)
+        parser.add_argument(f"--checkpoint_model_name_{name}", type=str)
+        parser.add_argument(f"--start_from_checkpoint_{name}", type=str2bool)
+        parser.add_argument(f"--freeze_weights_{name}", type=str2bool)
+        parser.add_argument(f"--input_size_enc_{name}", type=int)
+        parser.add_argument(f"--hidden_size_enc_{name}", type=int)
+        parser.add_argument(f"--n_layers_enc_{name}", type=int)
+        parser.add_argument(f"--activation_enc_{name}", type=str)
+        parser.add_argument(f"--dropout_enc_{name}", type=float)
+        parser.add_argument(f"--hidden_size_dec_{name}", type=int)
+        parser.add_argument(f"--dropout_dec_{name}", type=float)
+        parser.add_argument(f"--n_layers_dec_{name}", type=int)
+        parser.add_argument(f"--activation_dec_{name}", type=str)
+        parser.add_argument(f"--output_size_dec_{name}", type=int)
+        parser.add_argument(f"--latent_size_{name}", type=int)
+        parser.add_argument(f"--loss_function_{name}", type=str)
+
+        return parser
diff --git a/src/gt4sd/frameworks/granular/ml/models/mlp_predictor/__init__.py b/src/gt4sd/frameworks/granular/ml/models/mlp_predictor/__init__.py
new file mode 100644
index 000000000..9db978bf2
--- /dev/null
+++ b/src/gt4sd/frameworks/granular/ml/models/mlp_predictor/__init__.py
@@ -0,0 +1,3 @@
+"""Initialize MLP predictor module."""
+
+from .core import MlpPredictor  # noqa: F401
diff --git a/src/gt4sd/frameworks/granular/ml/models/mlp_predictor/core.py b/src/gt4sd/frameworks/granular/ml/models/mlp_predictor/core.py
new file mode 100644
index 000000000..71851e584
--- /dev/null
+++ b/src/gt4sd/frameworks/granular/ml/models/mlp_predictor/core.py
@@ -0,0 +1,180 @@
+"""MLP predictor implementation."""
+
+import logging
+from argparse import ArgumentParser
+from typing import Any, Dict, List, Optional, Tuple
+
+from ....arg_parser.utils import str2bool
+from ..base_model import GranularBaseModel
+from ..loss import LOSS_FACTORY
+from ..module import Mlp
+
+logger = logging.getLogger(__name__)
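A minimal usage sketch of the autoencoder defined above, with hypothetical sizes and data mapping (assumes the package layout introduced by this patch):

```python
import torch

from gt4sd.frameworks.granular.ml.models.mlp_auto_encoder import MlpAutoEncoder

# Toy round trip: encode 16-dimensional vectors into a 4-dimensional latent
# space and train against a reconstruction target.
model = MlpAutoEncoder(
    name="ae0",
    position=0,
    data={"input": "x", "target": "x"},
    input_size_enc=16,
    output_size_dec=16,
    latent_size=4,
    loss_function="mse",
)
x = torch.randn(8, 16)
z, loss, logs = model.step(input_data=x, target_data=x)
print(z.shape, logs["reconstruction_loss"])  # torch.Size([8, 4]) and a scalar tensor
```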
+logger.addHandler(logging.NullHandler()) + + +class MlpPredictor(GranularBaseModel): + """MlpPredictor - Multi Layer Perceptron predictor.""" + + def __init__( + self, + name: str, + from_position: List[int], + data: Dict[str, str], + input_size: int, + hidden_size: int, + output_size: int, + n_layers: int, + activation: str, + dropout: float, + loss_function: str, + class_weights: Optional[List[float]] = None, + **kwargs, + ) -> None: + """Construct MlpPredictor. + + Args: + name: model name. + from_position: list of input model positions. + data: data name mappings. + input_size: size of the input. + hidden_size: size of the hidden layers. + output_size: size of the output. + n_layers: number of layers. + activation: name of the activation. + dropout: dropout rate. + loss_function: name of the loss function. + class_weights: weights for the classes. Defaults to None, a.k.a., no weighting. + + Raises: + ValueError: in case the provided loss function is not supported. + """ + super().__init__(name=name, data=data) + self.from_position = from_position + self.target_key = name + "_" + data["target"] + self.loss_function_name = loss_function.lower() + if self.loss_function_name not in LOSS_FACTORY: + raise ValueError( + f"loss_function={self.loss_function_name} not supported. Pick a valid one: {sorted(list(LOSS_FACTORY.keys()))}" + ) + self.loss_function = LOSS_FACTORY[self.loss_function_name] + self.class_weights = class_weights + self.mlp = Mlp( + input_size=input_size, + hidden_size=hidden_size, + output_size=output_size, + n_layers=n_layers, + activation=activation, + dropout=dropout, + ) + + def _run_step(self, x: Any, *args, **kwargs) -> Any: + """Run a step in the model. + + Args: + x: model input. + + Returns: + model step output. + """ + return self.mlp(x) + + def predict(self, x: Any, *args, **kwargs) -> Any: + """Forward pass in the model. + + Args: + x: model input. + + Returns: + model output. + """ + return self._run_step(x) + + def step( + self, + input_data: Any, + target_data: Any, + device: str = "cpu", + current_epoch: int = 0, + *args, + **kwargs, + ) -> Tuple[Any, Any, Any]: + """Training step for the model. + + Args: + input_data: input for the step. + target_data: target for the step. + device: string representing the device to use. Defaults to "cpu". + current_epoch: current epoch. Defaults to 0. + + Returns: + a tuple containing the step output, the loss and the logs for the module. + """ + output = self._run_step(input_data) + loss = self.loss_function(output, target_data) + logs = {f"{self.loss_function_name}_loss": loss, "loss": loss} + return output, loss, logs + + def val_step( + self, + input_data: Any, + target_data: Any, + device: str = "cpu", + current_epoch: int = 0, + *args, + **kwargs, + ) -> Any: + """Validation step for the model. + + Args: + input_data: input for the step. + target_data: target for the step. + device: string representing the device to use. Defaults to "cpu". + current_epoch: current epoch. Defaults to 0. + + Returns: + a tuple containing the step output, the loss and the logs for the module. 
+ """ + output = self._run_step(input_data) + loss = self.loss_function(output, target_data) + logs = {f"{self.loss_function_name}_loss": loss, "loss": loss} + + if self.loss_function_name == "bce": + output_label = (output > 0.5).float() + correct_label = (output_label == target_data).float().sum() + accuracy = correct_label / output_label.shape[0] + logs["accuracy"] = accuracy + return output, loss, logs + + @staticmethod + def add_model_specific_args( + parent_parser: ArgumentParser, name: str, *args, **kwargs + ) -> ArgumentParser: + """Adding to a parser model specific arguments. + + Args: + parent_parser: patent parser. + name: model name. + + Returns: + updated parser. + """ + parser = ArgumentParser(parents=[parent_parser], add_help=False) + + parser.add_argument(f"--data_path_{name}", type=str) + parser.add_argument(f"--data_file_{name}", type=str) + parser.add_argument(f"--dataset_type_{name}", type=str) + parser.add_argument(f"--target_{name}", type=str) + parser.add_argument(f"--from_position_{name}", type=int, nargs="+") + parser.add_argument(f"--checkpoint_path_{name}", type=str) + parser.add_argument(f"--checkpoint_model_name_{name}", type=str) + parser.add_argument(f"--start_from_checkpoint_{name}", type=str2bool) + parser.add_argument(f"--freeze_weights_{name}", type=str2bool) + parser.add_argument(f"--n_layers_{name}", type=int) + parser.add_argument(f"--activation_{name}", type=str) + parser.add_argument(f"--dropout_{name}", type=float) + parser.add_argument(f"--loss_function_{name}", type=str) + parser.add_argument(f"--hidden_size_{name}", type=int) + parser.add_argument(f"--output_size_{name}", type=int) + + return parser diff --git a/src/gt4sd/frameworks/granular/ml/models/model_builder.py b/src/gt4sd/frameworks/granular/ml/models/model_builder.py new file mode 100644 index 000000000..0d649315a --- /dev/null +++ b/src/gt4sd/frameworks/granular/ml/models/model_builder.py @@ -0,0 +1,128 @@ +"""Model builder module.""" + +import logging +from collections import OrderedDict +from typing import Any, Dict, List +from typing import OrderedDict as OrderedDictType + +import torch + +from ....torch import device_claim +from . import ARCHITECTURE_FACTORY +from .base_model import GranularBaseModel + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +def build_model(architecture: Dict[str, Any]) -> GranularBaseModel: + """Build model from architecture configuration. + + Args: + architecture: architecture configuration. + + Returns: + built model. + """ + model_name = architecture["name"] + model_type = architecture["type"].lower() + if model_type not in ARCHITECTURE_FACTORY: + raise ValueError( + f"model_type={model_type} not supported. 
Pick a valid one: {sorted(ARCHITECTURE_FACTORY.keys())}"
+        )
+    model = ARCHITECTURE_FACTORY[model_type](
+        data=architecture["data"], **architecture["hparams"]
+    )
+
+    if architecture["start_from_checkpoint"]:
+        loaded_params = torch.load(
+            architecture["hparams"]["checkpoint_path"], map_location=device_claim(None)
+        )
+        loaded_architecture_latent = loaded_params["hyper_parameters"][
+            "architecture_latent_models"
+        ]
+        loaded_architecture_autoencoder = loaded_params["hyper_parameters"][
+            "architecture_autoencoders"
+        ]
+        for architecture_autoencoder in loaded_architecture_autoencoder:
+            if model_name == architecture_autoencoder["name"]:
+                architecture = architecture_autoencoder
+        for architecture_latent in loaded_architecture_latent:
+            if model_name == architecture_latent["name"]:
+                architecture = architecture_latent
+        loaded_state_dict: OrderedDictType[str, torch.Tensor] = OrderedDict()
+        for layer_name in loaded_params["state_dict"]:
+            state_model_name, *layer_name_elements = layer_name.split(".")
+            state_name = ".".join(layer_name_elements)
+            try:
+                checkpoint_model_name = architecture["hparams"]["checkpoint_model_name"]
+            except Exception:
+                checkpoint_model_name = None
+            if (
+                state_model_name == model_name
+                or state_model_name == checkpoint_model_name
+            ):
+                loaded_state_dict[state_name] = loaded_params["state_dict"][layer_name]
+        model.load_state_dict(loaded_state_dict)
+    model.name = model_name
+    model.data = architecture["data"]
+    model.target_key = model_name + "_" + architecture["data"]["target"]
+    try:
+        freeze_weights = architecture["freeze_weights"]
+    except KeyError:
+        freeze_weights = None
+
+    if freeze_weights:
+        for param in model.parameters():
+            param.requires_grad = False
+        model.eval()
+    if model_type == "mlp_predictor":
+        model.from_position = architecture["from_position"]
+    else:
+        model.position = architecture["position"]
+        model.input_key = model_name + "_" + architecture["data"]["input"]
+    return model
+
+
+def building_models(architectures: List[Dict[str, Any]]) -> List[GranularBaseModel]:
+    """Build models given architecture configurations.
+
+    Args:
+        architectures: list of architecture configurations.
+
+    Returns:
+        a list of models.
+    """
+    return [build_model(architecture) for architecture in architectures]
+
+
+def define_latent_models_input_size(
+    architecture_autoencoders: List[Dict[str, Any]],
+    architecture_latent_models: List[Dict[str, Any]],
+) -> List[Dict[str, Any]]:
+    """Define latent models input size and return the updated configurations.
+
+    Args:
+        architecture_autoencoders: list of autoencoder architecture configurations.
+        architecture_latent_models: list of latent model architecture configurations.
+
+    Returns:
+        list of updated latent model architecture configurations.
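For orientation, a sketch of the architecture dictionary shape that build_model consumes; the keys are inferred from the accesses above and all values are purely illustrative:

```python
# Illustrative configuration only; "mlp_autoencoder" stands in for an actual
# ARCHITECTURE_FACTORY key, which is defined elsewhere in this package.
architecture = {
    "name": "ae0",
    "type": "mlp_autoencoder",
    "start_from_checkpoint": False,
    "position": 0,
    "data": {"input": "x", "target": "x"},
    "hparams": {"input_size_enc": 16, "output_size_dec": 16, "latent_size": 4},
}
# model = build_model(architecture)
```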
+ """ + size_autoencoder: Dict[str, int] = dict() + for architecture in architecture_autoencoders: + if architecture["position"] not in size_autoencoder.keys(): + size_autoencoder[architecture["position"]] = architecture["hparams"][ + "latent_size" + ] + else: + logger.warning(f"position for architecture={architecture} is not unique!") + + updated_architecture_latent_models = [] + for _, architecture in enumerate(architecture_latent_models): + architecture["hparams"]["input_size"] = sum( + [size_autoencoder[pos] for pos in architecture["from_position"]] + ) + updated_architecture_latent_models.append(architecture) + + return updated_architecture_latent_models diff --git a/src/gt4sd/frameworks/granular/ml/models/module.py b/src/gt4sd/frameworks/granular/ml/models/module.py new file mode 100644 index 000000000..18c35bcf3 --- /dev/null +++ b/src/gt4sd/frameworks/granular/ml/models/module.py @@ -0,0 +1,1115 @@ +"""Generic modules.""" + +import copy +import math +from typing import Any, Callable, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from ....torch import get_device_from_tensor +from ...tokenizer import Tokenizer +from .activation import ACTIVATION_FACTORY + + +class Mlp(nn.Module): + """MLP module.""" + + def __init__( + self, + input_size: int, + hidden_size: int, + output_size: int, + n_layers: int, + activation: str, + dropout: float, + **kwargs, + ) -> None: + """Construct Mlp. + + Args: + input_size: size of the input. + hidden_size: size of the hidden layers. + output_size: size of the output. + n_layers: number of layers. + activation: name of the activation. + dropout: dropout rate. + """ + super().__init__() + activation = activation.lower() + self.activation = ACTIVATION_FACTORY.get(activation, None) + self.first_layer = nn.Linear(input_size, hidden_size) + middle_layers: List[nn.Module] = list() + for _ in range(n_layers): + middle_layers.append(nn.Linear(hidden_size, hidden_size)) + middle_layers.append(nn.ReLU()) + middle_layers.append(nn.Dropout(p=dropout)) + self.middle_layers = nn.Sequential(*middle_layers) + self.last_layer = nn.Linear(hidden_size, output_size) + self.relu = nn.ReLU() + self.output_dim = output_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass in the model. + + Args: + x: model input. + + Returns: + model output. + """ + z = self.first_layer(x) + z = self.relu(z) + z = self.middle_layers(z) + z = self.last_layer(z) + if self.activation: + z = self.activation(z) + return z + + +class MlpEncoder(Mlp): + """MLP encoder.""" + + def __init__( + self, + input_size: int, + hidden_size: int, + output_size: int, + n_layers: int, + activation: str, + dropout: float = 0.0, + **kwargs, + ) -> None: + """Construct MlpEncoder. + + Args: + input_size: size of the input. + hidden_size: size of the hidden layers. + output_size: size of the output. + n_layers: number of layers. + activation: name of the activation. + dropout: dropout rate. Defaults to 0.0. + """ + super().__init__( + input_size=input_size, + hidden_size=hidden_size, + output_size=output_size, + n_layers=n_layers, + activation=activation, + dropout=dropout, + ) + + +class MlpDecoder(Mlp): + """MLP decoder.""" + + def __init__( + self, + latent_size: int, + hidden_size: int, + output_size: int, + n_layers: int, + activation: str, + dropout: float = 0.0, + **kwargs, + ) -> None: + """Construct MlpEncoder. + + Args: + latent_size: size of the input. + hidden_size: size of the hidden layers. 
+ output_size: size of the output. + n_layers: number of layers. + activation: name of the activation. + dropout: dropout rate. Defaults to 0.0. + """ + super().__init__( + input_size=latent_size, + hidden_size=hidden_size, + output_size=output_size, + n_layers=n_layers, + activation=activation, + dropout=dropout, + ) + + +class RnnEncoder(nn.Module): + """RNN encoder.""" + + def __init__( + self, + vocab_size: int, + embedding_size: int, + hidden_size: int = 256, + n_layers: int = 2, + bidirectional: bool = False, + latent_size: int = 196, + ) -> None: + """Construct RnnEncoder. + + Args: + vocab_size: size of the vocabulary. + embedding_size: size of the embedding vectors. + hidden_size: hidden size. Defaults to 256. + n_layers: number of layers. Defaults to 2. + bidirectional: whether the RNN cell is bidirectional. Defaults to False. + latent_size: latent size. Defaults to 196. + """ + super().__init__() + self.input_size = embedding_size + self.hidden_size = hidden_size + self.bidirectional = bidirectional + self.latent_size = latent_size + self.hidden_factor = (2 if bidirectional else 1) * n_layers + self.rnn = nn.GRU( + input_size=embedding_size, + hidden_size=hidden_size, + num_layers=n_layers, + bidirectional=bidirectional, + batch_first=True, + ) + self.embedding = nn.Embedding( + num_embeddings=vocab_size, embedding_dim=embedding_size + ) + + def forward( + self, input_sequence: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass in the model. + + Args: + input_sequence: input sequence tensor. + + Returns: + a tuple containing hidden state and embedded sequence. + """ + input_embedding = self.embedding(input_sequence) + _, hidden = self.rnn(input_embedding) + hidden = hidden.permute(1, 0, 2) + hidden = hidden.contiguous().view(hidden.size(0), -1) + return hidden, input_embedding + + +class RnnDecoder(nn.Module): + """RNN decoder.""" + + def __init__( + self, + vocab_size: int, + embedding_size: int, + hidden_size: int = 256, + n_layers: int = 2, + latent_size: int = 196, + ) -> None: + """Construct RnnDecoder. + + Args: + vocab_size: size of the vocabulary. + embedding_size: size of the embedding vectors. + hidden_size: hidden size. Defaults to 256. + n_layers: number of layers. Defaults to 2. + latent_size: latent size. Defaults to 196. + """ + super().__init__() + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.n_layers = n_layers + self.latent_size = latent_size + self.hidden_factor = n_layers + self.rnn = nn.GRU( + input_size=embedding_size, + hidden_size=hidden_size, + num_layers=n_layers, + batch_first=True, + ) + self.latent2hidden = torch.nn.Linear( + latent_size, hidden_size * self.hidden_factor + ) + self.outputs2vocab = torch.nn.Linear(hidden_size, vocab_size) + + def forward( + self, latent: torch.Tensor, input_embedding: torch.Tensor + ) -> torch.Tensor: + """Forward pass in the model. + + Args: + latent: latent tensor. + input_embedding: input embedding. + + Returns: + model output. 
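RnnEncoder above flattens the GRU's layer-major hidden state into a single vector per batch element; a standalone shape check in plain PyTorch (toy sizes):

```python
import torch
from torch import nn

# hidden comes out as (n_layers * n_directions, batch, hidden_size); the
# encoder permutes it to batch-major and flattens the remaining dimensions.
rnn = nn.GRU(input_size=8, hidden_size=16, num_layers=2, batch_first=True)
x = torch.randn(4, 10, 8)  # (batch, seq_len, embedding_size)
_, hidden = rnn(x)  # (2, 4, 16)
hidden = hidden.permute(1, 0, 2).contiguous().view(4, -1)
print(hidden.shape)  # torch.Size([4, 32])
```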
+ """ + hidden = self.latent2hidden(latent) + hidden = hidden.view(-1, self.hidden_factor, self.hidden_size) + hidden = hidden.permute(1, 0, 2).contiguous() + hidden = torch.tanh(hidden) + outputs, _ = self.rnn(input_embedding, hidden) + b, seq_len, hsize = outputs.size() + outputs = outputs.contiguous().view(-1, hsize) + outputs = self.outputs2vocab(outputs) + outputs = outputs.view(b, seq_len, self.vocab_size) + return outputs + + def inference_direct( + self, + latent: torch.Tensor, + embedding: nn.Module, + tokenizer: Tokenizer, + max_len: int, + ) -> Tuple[List[str], torch.Tensor]: + """Direct inference from latent space. + + Args: + latent: latent tensor. + embedding: embedding module. + tokenizer: tokenizer. + max_len: maximum sequence length. + + Returns: + a tuple containing decoded strings and indices. + """ + batch_size = latent.size(0) + hidden = self.latent2hidden(latent) + hidden = hidden.view(batch_size, self.hidden_factor, self.hidden_size) + hidden = hidden.permute(1, 0, 2).contiguous() + hidden = torch.tanh(hidden) + input_sequence = torch.full( + (batch_size, 1), tokenizer.sos_token_id, device=latent.device + ).long() + logits_list = [] + for t in range(max_len): + input_embedding = embedding(input_sequence) + output, hidden = self.rnn(input_embedding, hidden) + logits = self.outputs2vocab(output) + logits_list.append(logits) + input_sequence = torch.argmax(logits, dim=-1) + + logits_tensor = torch.cat(logits_list, dim=1) + token_indices = torch.argmax(logits_tensor, dim=-1) + decoded_texts = [] + for index in range(batch_size): + tokens = [ + tokenizer.convert_id_to_token(vocab_index.item()) + for vocab_index in token_indices[index] + ] + text = "".join(tokens).split()[0] + decoded_texts.append(text) + return decoded_texts, token_indices + + +def attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: Optional[torch.Tensor] = None, + dropout: Optional[nn.Module] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute scaled dot product attention (adapted from Viswani et al.). + + Args: + query: query tensor. + key: key tensor. + value: value tesor. + mask: mask to apply on attention score. Defaults to None, a.k.a., no mask. + dropout: dropout layer. Defaults to None, a.k.a., no dropout. + + Returns: + a tuple containing the applied attention and the attention weights. + """ + d_k = query.size(-1) + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e9) + p_attn = F.softmax(scores, dim=-1) + if dropout is not None: + p_attn = dropout(p_attn) + return torch.matmul(p_attn, value), p_attn + + +def clones(module: nn.Module, n: int) -> nn.Module: + """Produce N identical layers (adapted from http://nlp.seas.harvard.edu/2018/04/03/attention.html). + + Args: + module: a module. + n: number of clones. + + Returns: + a module list. + """ + return torch.nn.ModuleList([copy.deepcopy(module) for _ in range(n)]) + + +def subsequent_mask(size: int) -> torch.Tensor: + """Mask out subsequent positions (adapted from http://nlp.seas.harvard.edu/2018/04/03/attention.html). + + Args: + size: size of the attention matrix. + + Returns: + the mask tensor. 
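A quick numerical check of the scaled dot-product attention above combined with a causal mask of the kind subsequent_mask produces (toy tensors):

```python
import math

import torch
import torch.nn.functional as F

# With a lower-triangular mask, the first query position can only attend
# to itself, so its attention row collapses to [1, 0, 0].
q = k = v = torch.randn(1, 3, 4)  # (batch, seq_len, d_k)
mask = torch.tril(torch.ones(1, 3, 3))  # same effect as subsequent_mask(3)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
scores = scores.masked_fill(mask == 0, -1e9)
p_attn = F.softmax(scores, dim=-1)
print(p_attn[0, 0])  # tensor([1., 0., 0.])
```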
+ """ + attn_shape = (1, size, size) + subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype("uint8") + return torch.from_numpy(subsequent_mask) == 0 + + +class ListModule(torch.nn.Module): + """Create single pytorch module from list of modules.""" + + def __init__(self, *args) -> None: + """Construct ListModule.""" + super().__init__() + idx = 0 + for module in args: + self.add_module(str(idx), module) + idx += 1 + + def __getitem__(self, idx: int) -> Any: + """Get item from the module list. + + Args: + idx: index of the item. + + Raises: + IndexError: in case the index is out of range. + + Returns: + the item. + """ + if idx < 0 or idx >= len(self._modules): + raise IndexError("index {} is out of range".format(idx)) + it = iter(self._modules.values()) + for i in range(idx): + next(it) + return next(it) + + def __iter__(self) -> Any: + """An iterator over the module list values. + + Returns: + the iterator over values. + """ + return iter(self._modules.values()) + + def __len__(self): + """Length of the module list. + + Returns: + the number of modules. + """ + return len(self._modules) + + +class MultiHeadedAttention(nn.Module): + """"Multihead attention implementation (based on Vaswani et al.).""" + + def __init__(self, h, d_model, dropout=0.1) -> None: + """Construct MultiHeadedAttention. + + Args: + h: number of heads. + d_model: model size. + dropout: dropout rate. Defaults to 0.1. + """ + super().__init__() + assert d_model % h == 0 + # we assume d_v always equals d_k + self.d_k = d_model // h + self.h = h + self.linears = clones(nn.Linear(d_model, d_model), 4) + self.attn = None + self.dropout = nn.Dropout(p=dropout) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: Optional[torch.Tensor] = None, + return_attn: bool = False, + ) -> Any: + """Forward pass in the model. + + Args: + query: query tensor. + key: key tensor. + value: value tesor. + mask: mask to apply on attention score. Defaults to None, a.k.a., no mask. + return_attn: whether to return the attention matrix instead of the linear layer output. + Defaults to False, a.k.a, do not return attention. + + Returns: + either the last layer output of the attention matrix. + """ + if mask is not None: + # Same mask applied to all h heads + mask = mask.unsqueeze(1) + nbatches = query.size(0) + + # 1) do all the linear projections in batch from d_model => h x d_k + query, key, value = [ + linear_layer(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) + for linear_layer, x in zip(self.linears, (query, key, value)) # type:ignore + ] + + # 2) apply attention on all the projected vectors in batch + x, self.attn = attention( # type:ignore + query, key, value, mask=mask, dropout=self.dropout + ) # type:ignore + + # 3) "concat" using a view and apply a final linear + x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k) + if return_attn: + return self.attn + else: + return self.linears[-1](x) # type:ignore + + +class PositionwiseFeedForward(nn.Module): + """Feed forward implementation.""" + + def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1) -> None: + """Construct PositionwiseFeedForward. + + Args: + d_model: model size. + d_ff: feed forward size. + dropout: dropout rate. Defaults to 0.1. + """ + super().__init__() + self.w_1 = nn.Linear(d_model, d_ff) + self.w_2 = nn.Linear(d_ff, d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass in the model. + + Args: + x: input tensor. 
+ + Returns: + feed forward output. + """ + return self.w_2(self.dropout(F.relu(self.w_1(x)))) + + +class ConvBottleneck(nn.Module): + """Set of convolutional layers to reduce memory matrix to single latent vector.""" + + def __init__(self, size: int, number_of_layers: int = 3) -> None: + """Construct ConvBottleneck. + + Args: + size: input size. + number_of_layers: convolutional layers number. Defaults to 3. + """ + super().__init__() + conv_layers = [] + in_d = size + first = True + for i in range(number_of_layers): + out_d = int((in_d - 64) // 2 + 64) + if first: + kernel_size = 9 + first = False + else: + kernel_size = 8 + if i == 2: + out_d = 64 + conv_layers.append( + nn.Sequential(nn.Conv1d(in_d, out_d, kernel_size), nn.MaxPool1d(2)) + ) + in_d = out_d + self.conv_layers = ListModule(*conv_layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass in the model. + + Args: + x: input tensor. + + Returns: + model output. + """ + for conv in self.conv_layers: + x = F.relu(conv(x)) + return x + + +class DeconvBottleneck(nn.Module): + """Set of deconvolutional layers to reshape latent vector back into memory matrix.""" + + def __init__(self, size: int, seq_len: int, dim_factor: int) -> None: + """Construct DeconvBottleneck. + + Args: + size: size of the deconvolutional padding. + seq_len: length of the sequence. + dim_factor: dimensionality factor. + """ + super().__init__() + deconv_layers = [] + + in_d = 64 + + out_fac = 9 * dim_factor + 8 + out_fac = out_fac - 1 + 50 + 1 + diff_seq = out_fac - seq_len + + for i in range(3): + out_d = (size - in_d) // 4 + in_d + stride = 3 + padding = 3 + dilation = 1 + kernel_size = 11 + output_padding = 0 + if i == 2: + out_d = size + stride = 1 + dilation = 5 + if diff_seq % 2 == 0: + padding = int(diff_seq / 2) + output_padding = 0 + else: + padding = math.ceil(diff_seq / 2) + output_padding = 1 + + deconv_layers.append( + nn.Sequential( + nn.ConvTranspose1d( + in_d, + out_d, + kernel_size, + dilation=dilation, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + ) + ) + in_d = out_d + self.deconv_layers = ListModule(*deconv_layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass in the model. + + Args: + x: input tensor. + + Returns: + model output. + """ + for deconv in self.deconv_layers: + x = F.relu(deconv(x)) + return x + + +class Embeddings(nn.Module): + "Transforms input token id tensors to size d_model embeddings." + + def __init__(self, d_model: int, vocab_size: int) -> None: + """Costruct Embeddings. + + Args: + d_model: size of the embedding vectors. + vocab_size: size of the vocabulary. + """ + super().__init__() + self.lut = nn.Embedding(vocab_size, d_model) + self.d_model = d_model + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass in the model. + + Args: + x: input tensor. + + Returns: + model output. + """ + return self.lut(x) * math.sqrt(self.d_model) + + +class PositionalEncoding(nn.Module): + """Static sinusoidal positional encoding layer.""" + + def __init__(self, d_model: int, dropout: float, max_len: int = 5000) -> None: + """Construct PositionalEncoding. + + Args: + d_model: model size. + dropout: dropout rate. + max_len: maximum sequence length. Defaults to 5000. 
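The encoding table this layer registers is a fixed function of position and channel; the same formula reproduced standalone at inspectable sizes:

```python
import math

import torch

# Sine on even channels, cosine on odd channels, with frequencies decaying
# geometrically in the channel index; row t is added to the embedding at
# sequence position t.
d_model, max_len = 4, 5
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
print(pe.shape)  # torch.Size([5, 4])
```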
+ """ + super().__init__() + self.dropout = nn.Dropout(p=dropout) + + # compute the positional encodings once in log space + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer("pe", pe) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass in the model. + + Args: + x: input tensor. + + Returns: + model output. + """ + x = x + torch.autograd.Variable( + self.pe[:, : x.size(1)], requires_grad=False # type:ignore + ) + return self.dropout(x) + + +class TorchLayerNorm(nn.Module): + """Layer normalization using torch BatchNorm1d.""" + + def __init__(self, features: int, eps=1e-6) -> None: + """Construct TorchLayerNorm. + + Args: + features: number of features. + eps: espilon to add to denominator for numerical stability. Defaults to 1e-6. + """ + super().__init__() + self.bn = nn.BatchNorm1d(features, eps=eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass in the model. + + Args: + x: input tensor. + + Returns: + model output. + """ + return self.bn(x) + + +class LayerNorm(nn.Module): + """Custom layer normalization.""" + + def __init__(self, features: int, eps=1e-6) -> None: + """Construct LayerNorm. + + Args: + features: number of features. + eps: espilon to add to denominator for numerical stability. Defaults to 1e-6. + """ + super().__init__() + self.a = nn.Parameter(torch.ones(features)) + self.b = nn.Parameter(torch.zeros(features)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass in the model. + + Args: + x: input tensor. + + Returns: + model output. + """ + mean = x.mean(-1, keepdim=True) + std = x.std(-1, keepdim=True) + return self.a * (x - mean) / (std + self.eps) + self.b + + +class SublayerConnection(nn.Module): + """A residual connection followed by a layer normalization. + + Note for code simplicity the norm is first as opposed to last. A dropout layer + is also applied. + """ + + def __init__(self, size: int, dropout: float) -> None: + super().__init__() + self.norm = LayerNorm(size) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor, sublayer: Callable) -> torch.Tensor: + """Forward pass in the model. + + Args: + x: input tensor. + sublayer: a callable returning a tensor. + + Returns: + model output. + """ + return x + self.dropout(sublayer(self.norm(x))) + + +class TransformerEncoder(nn.Module): + """Base transformer encoder architecture.""" + + def __init__( + self, + hidden_size: int, + ff_size: int, + seq_len: int, + dropout: float, + heads: int, + n_layers_enc: int, + vocab_size: int, + bypass_bottleneck: bool, + ) -> None: + """Construct TransformerEncoder. + + Args: + hidden_size: hidden size. + ff_size: feed forward size. + seq_len: sequence length. + dropout: dropout rate. + heads: number of heads. + n_layers_enc: number of encoding layers. + vocab_size: vocabulary size. + bypass_bottleneck: whether the bottleneck should be by passed. 
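A quick check of the custom LayerNorm defined above: with its default a = 1 and b = 0 it standardizes the last dimension (toy input):

```python
import torch

# Reproduces LayerNorm.forward with a == 1 and b == 0: per-row
# standardization along the last dimension.
x = torch.randn(2, 5)
eps = 1e-6
normed = (x - x.mean(-1, keepdim=True)) / (x.std(-1, keepdim=True) + eps)
print(normed.mean(-1))  # ~0 per row
print(normed.std(-1))  # ~1 per row
```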
+ """ + super().__init__() + + self.position = PositionalEncoding(hidden_size, dropout) + self.embedding = nn.Sequential( + Embeddings(hidden_size, vocab_size * 2), self.position + ) + + self.self_attn = MultiHeadedAttention(heads, hidden_size) + self.feed_forward = PositionwiseFeedForward(hidden_size, ff_size, dropout) + layer = TransformerEncoderLayer( + hidden_size, seq_len, self.self_attn, self.feed_forward, dropout + ) + self.layers = clones(layer, n_layers_enc) + + self.conv_bottleneck = ConvBottleneck(hidden_size) + self.norm = LayerNorm(hidden_size) + + self.bypass_bottleneck = bypass_bottleneck + conv_output_shape = self.calc_output_shape(seq_len, hidden_size) + self.conv_output_len = conv_output_shape[1] * conv_output_shape[2] + self.conv_output_shape = conv_output_shape + + def calc_output_shape(self, seq_len: int, hidden_size: int): + """Compute output shape. + + Args: + seq_len: sequence length. + hidden_size: hidden size. + + Returns: + convolutional bottleneck output shape. + """ + x = torch.randn((1, hidden_size, seq_len)) + x_out = self.conv_bottleneck(x) + return x_out.shape + + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Forward pass in the model. + + Args: + x: input tensor. + mask: mask to apply in the attention layer. + + Returns: + model output. + """ + x = self.embedding(x) + for _, attn_layer in enumerate(self.layers): # type:ignore + x = attn_layer(x, mask) + mem = self.norm(x) + mem = mem.permute(0, 2, 1) + mem = self.conv_bottleneck(mem) + mem = mem.contiguous().view(mem.size(0), -1) + return mem + + +class TransformerEncoderLayer(nn.Module): + """Self-attention/feedforward implementation.""" + + def __init__( + self, + size: int, + seq_len: int, + self_attn: nn.Module, + feed_forward: nn.Module, + dropout: float, + ) -> None: + """Construct TransformerEncoderLayer. + + Args: + size: model size. + seq_len: sequence length. + self_attn: self-attention layer. + feed_forward: feed forward layer. + dropout: droupout rate. + """ + super().__init__() + self.size = size + self.seq_len = seq_len + self.self_attn = self_attn + self.feed_forward = feed_forward + self.sublayer = clones(SublayerConnection(self.size, dropout), 2) + + def forward( + self, x: torch.Tensor, mask: torch.Tensor, return_attn: bool = False + ) -> Any: + """Forward pass in the model. + + Args: + x: input tensor. + mask: mask to apply in the attention layer. + return_attn: whether to return the attention together with the output. + Defaults to False, return only encoder output. + + Returns: + model output. + """ + if return_attn: + attn = self.self_attn(x, x, x, mask, return_attn=True) + x = self.sublayer[0]( # type:ignore + x, lambda x: self.self_attn(x, x, x, mask) + ) + return self.sublayer[1](x, self.feed_forward), attn # type:ignore + else: + x = self.sublayer[0]( # type:ignore + x, lambda x: self.self_attn(x, x, x, mask) + ) + return self.sublayer[1](x, self.feed_forward) # type:ignore + + +class TransformerDecoder(nn.Module): + """Base transformer decoder architecture.""" + + def __init__( + self, + hidden_size: int, + ff_size: int, + seq_len: int, + dropout: float, + heads: int, + n_layers_dec: int, + latent_size: int, + vocab_size: int, + bypass_bottleneck: bool, + deconv_shape: Tuple[int, int, int], + ) -> None: + """Construct TransformerDecoder. + + Args: + hidden_size: hidden size. + ff_size: feed forward size. + seq_len: sequence length. + dropout: dropout rate. + heads: number of heads. + n_layers_enc: number of encoding layers. 
+ latent_size: latent size. + vocab_size: vocabulary size. + bypass_bottleneck: whether the bottleneck should be by passed. + deconv_shape: shape of the deconvoluted samples. A tuple with three + dimensions. + """ + super().__init__() + + self.position = PositionalEncoding(hidden_size, dropout) + self.embedding = nn.Sequential( + Embeddings(hidden_size, vocab_size), self.position + ) + self.attn_enc = MultiHeadedAttention(heads, hidden_size) + self.ff_enc = PositionwiseFeedForward(hidden_size, ff_size, dropout) + self.attn_dec_1 = MultiHeadedAttention(heads, hidden_size) + self.attn_dec_2 = MultiHeadedAttention(heads, hidden_size) + + self.ff_dec = PositionwiseFeedForward(hidden_size, ff_size, dropout) + + encoder_layers = TransformerEncoderLayer( + hidden_size, seq_len, self.attn_enc, self.ff_enc, dropout + ) + decoder_layers = TransformerDecoderLayer( + hidden_size, + seq_len, + self.attn_dec_1, + self.attn_dec_2, + self.ff_dec, + dropout, + ) + + self.final_encodes = clones(encoder_layers, 1) + self.layers = clones(decoder_layers, n_layers_dec) + self.norm = LayerNorm(hidden_size) + self.bypass_bottleneck = bypass_bottleneck + self.hidden_size = hidden_size + self.seq_len = seq_len + self.outputs2vocab = torch.nn.Linear(hidden_size, vocab_size) + self.deconv_shape = deconv_shape + self.deconv_bottleneck = DeconvBottleneck( + hidden_size, seq_len=seq_len, dim_factor=deconv_shape[2] + ) + self.linear = nn.Linear(latent_size, deconv_shape[2] * deconv_shape[1]) + + def forward( + self, + x: torch.Tensor, + mem: torch.Tensor, + src_mask: torch.Tensor, + tgt_mask: torch.Tensor, + ) -> torch.Tensor: + """Forward pass in the model. + + Args: + x: input tensor. + mem: memory tensor. + src_mask: source sequence mask. + tgt_mask: target sequence mask. + + Returns: + model output. + """ + x = self.embedding(x) + if not self.bypass_bottleneck: + mem = F.relu(self.linear(mem)) + mem = mem.view(-1, 64, self.deconv_shape[2]) + mem = self.deconv_bottleneck(mem) + mem = mem.permute(0, 2, 1) + for final_encode in self.final_encodes: # type:ignore + mem = final_encode(mem, src_mask) + mem = self.norm(mem) + for _, attn_layer in enumerate(self.layers): # type:ignore + x = attn_layer(x, mem, mem, src_mask, tgt_mask) + x = self.norm(x) + x = self.outputs2vocab(F.relu(x)) + return x + + def inference_direct( + self, + latent: torch.Tensor, + mask_lengths: torch.Tensor, + tokenizer: Tokenizer, + ) -> Tuple[List[str], torch.Tensor]: + """Direct inference from latent space. + + Args: + latent: latent tensor. + mask_lengths: masking tensor. + tokenizer: tokenizer. + + Returns: + a tuple containing decoded strings and indices. 
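inference_direct above decodes greedily, one token per step; the bare loop pattern with toy logits standing in for the decoder forward pass:

```python
import torch

# At step i, take the argmax over the vocabulary distribution for position i
# and append it to the running token matrix (sos id 0 is a stand-in).
batch_size, vocab_size = 2, 5
token_indices = torch.zeros(batch_size, 1, dtype=torch.long)  # <sos>
for i in range(3):
    logits = torch.randn(batch_size, i + 1, vocab_size)  # stand-in for self(...)
    prob = torch.softmax(logits[:, i, :], dim=-1)
    next_token = torch.argmax(prob, dim=1).unsqueeze(1)
    token_indices = torch.cat([token_indices, next_token], dim=1)
print(token_indices.shape)  # torch.Size([2, 4])
```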
+ """ + device = get_device_from_tensor(latent) + batch_size = latent.size(0) + token_indices = torch.full( + (batch_size, 1), tokenizer.sos_token_id, device=device + ).long() + + src_mask = torch.zeros((latent.shape[0], 1, self.seq_len), device=device) + + for index in range(mask_lengths.shape[0]): + mask_len = int(mask_lengths[index].item()) + src_mask[index, :, :mask_len] = torch.ones((1, 1, mask_len), device=device) + self.eval() + for i in range(self.seq_len - 1): + trg_mask = subsequent_mask(token_indices.size(1)).long().to(device) + logits = self( + torch.autograd.Variable(token_indices), latent, src_mask, trg_mask + ) + + prob = F.softmax(logits[:, i, :], dim=-1) + _, next_token = torch.max(prob, dim=1) + + next_token = next_token.unsqueeze(1) + token_indices = torch.cat([token_indices, next_token], dim=1) + + decoded_texts = [] + for index in range(batch_size): + tokens = [ + tokenizer.convert_id_to_token(vocab_index.item()) + for vocab_index in token_indices[index] + ] + text = "".join(tokens).split()[0] + decoded_texts.append(text) + return decoded_texts, token_indices + + +class TransformerDecoderLayer(nn.Module): + """Self-attention/source-attention/feedforward implementation.""" + + def __init__( + self, + size: int, + seq_len: int, + self_attn: nn.Module, + src_attn: nn.Module, + feed_forward: nn.Module, + dropout: float, + ) -> None: + """Construct TransformerDecoderLayer. + + Args: + size: model size. + seq_len: sequence length. + self_attn: self-attention layer. + src_attn: source attention layer. + feed_forward: feed forward layer. + dropout: droupout rate. + """ + super().__init__() + self.size = size + self.tgt_len = seq_len + self.self_attn = self_attn + self.src_attn = src_attn + self.feed_forward = feed_forward + self.sublayer = clones(SublayerConnection(self.size, dropout), 3) + + def forward( + self, + x: torch.Tensor, + memory_key: torch.Tensor, + memory_val: torch.Tensor, + src_mask: torch.Tensor, + tgt_mask: torch.Tensor, + return_attn: bool = False, + ) -> Any: + """Forward pass in the model. + + Args: + x: input tensor + memory_key: memory key tensor. + memory_val: memory value tensor.s + src_mask: mask to apply in the source attention layer. + tgt_mask: mask to apply in the target attention layer. + return_attn: whether to return the attention together with the output. + Defaults to False, return only encoder output. + + Returns: + model output. 
+ """ + m_key = memory_key + m_val = memory_val + if return_attn: + x = self.sublayer[0]( # type:ignore + x, lambda x: self.self_attn(x, x, x, tgt_mask) + ) + src_attn = self.src_attn(x, m_key, m_val, src_mask, return_attn=True) + x = self.sublayer[1]( # type:ignore + x, lambda x: self.src_attn(x, m_key, m_val, src_mask) + ) + return self.sublayer[2](x, self.feed_forward), src_attn # type:ignore + else: + x = self.sublayer[0]( # type:ignore + x, lambda x: self.self_attn(x, x, x, tgt_mask) + ) + x = self.sublayer[1]( # type:ignore + x, lambda x: self.src_attn(x, m_key, m_val, src_mask) + ) + return self.sublayer[2](x, self.feed_forward) # type:ignore diff --git a/src/gt4sd/frameworks/granular/ml/models/no_encoding/__init__.py b/src/gt4sd/frameworks/granular/ml/models/no_encoding/__init__.py new file mode 100644 index 000000000..ebd7f873f --- /dev/null +++ b/src/gt4sd/frameworks/granular/ml/models/no_encoding/__init__.py @@ -0,0 +1,3 @@ +"""Initialize no encoding module.""" + +from .core import NoEncoding # noqa: F401 diff --git a/src/gt4sd/frameworks/granular/ml/models/no_encoding/core.py b/src/gt4sd/frameworks/granular/ml/models/no_encoding/core.py new file mode 100644 index 000000000..7f1bec009 --- /dev/null +++ b/src/gt4sd/frameworks/granular/ml/models/no_encoding/core.py @@ -0,0 +1,180 @@ +"""NoEncoding implementation.""" + +from argparse import ArgumentParser +from typing import Any, Dict, Tuple + +from ....arg_parser.utils import str2bool +from ..base_model import GranularEncoderDecoderModel + + +class NoEncoding(GranularEncoderDecoderModel): + """NoEncoding module for adding inputs directly in the latent space.""" + + def __init__( + self, + name: str, + position: int, + data: Dict[str, str], + latent_size: int = 2, + **kwargs, + ) -> None: + """Construct NoEncoding. + + Args: + name: model name. + position: position of the model. + data: data name mappings. + latent_size: latent size. Defaults to 2. + """ + super().__init__(name=name, data=data) + self.position = position + self.input_key = f"{name}_{data['input']}" + self.target_key = f"{name}_{data['target']}" + self.latent_size = latent_size + + def decode(self, z: Any, *args, **kwargs) -> Any: + """Decode a latent space point. + + Args: + z: latent point. + + Returns: + decoded sample. + """ + return z + + def encode(self, x: Any, *args, **kwargs) -> Any: + """Encode a sample. + + Args: + x: input sample. + + Returns: + latent encoding. + """ + return x + + def inference(self, z: Any, *args, **kwargs) -> Any: + """Run the model in inference mode. + + Args: + z: sample. + + Returns: + generated output. + """ + return z + + def forward(self, x: Any, *args, **kwargs) -> Any: + """Forward pass in the model. + + Args: + x: model input. + + Returns: + model output. + """ + return x + + def _run_step(self, x: Any, *args, **kwargs) -> Any: + """Run a step in the model. + + Args: + x: model input. + + Returns: + model step output. + """ + return x + + def encode_decode(self, x: Any, *args, **kwargs) -> Any: + """Encode and decode a sample. + + Args: + x: input sample. + + Returns: + decoded sample. + """ + z, x_out = self._run_step(x) + return z, x_out + + def step( + self, + input_data: Any, + target_data: Any, + device: str = "cpu", + current_epoch: int = 0, + *args, + **kwargs, + ) -> Tuple[Any, Any, Any]: + """Training step for the model. + + Args: + input_data: input for the step. + target_data: target for the step. + device: string representing the device to use. Defaults to "cpu". + current_epoch: current epoch. Defaults to 0. 
+
+        Returns:
+            a tuple containing the step output, the loss and the logs for the module.
+        """
+        z = input_data
+
+        loss = 0
+        logs = {"loss": loss}
+
+        return z, loss, logs
+
+    def val_step(
+        self,
+        input_data: Any,
+        target_data: Any,
+        device: str = "cpu",
+        current_epoch: int = 0,
+        *args,
+        **kwargs,
+    ) -> Any:
+        """Validation step for the model.
+
+        Args:
+            input_data: input for the step.
+            target_data: target for the step.
+            device: string representing the device to use. Defaults to "cpu".
+            current_epoch: current epoch. Defaults to 0.
+
+        Returns:
+            a tuple containing the step output, the loss and the logs for the module.
+        """
+        z = input_data
+
+        loss = 0
+        logs = {"loss": loss}
+        return z, loss, logs
+
+    @staticmethod
+    def add_model_specific_args(
+        parent_parser: ArgumentParser, name: str, *args, **kwargs
+    ) -> ArgumentParser:
+        """Add model specific arguments to a parser.
+
+        Args:
+            parent_parser: parent parser.
+            name: model name.
+
+        Returns:
+            updated parser.
+        """
+        parser = ArgumentParser(parents=[parent_parser], add_help=False)
+
+        parser.add_argument(f"--data_path_{name}", type=str)
+        parser.add_argument(f"--data_file_{name}", type=str)
+        parser.add_argument(f"--position_{name}", type=int, nargs="+")
+        parser.add_argument(f"--input_{name}", type=str)
+        parser.add_argument(f"--target_{name}", type=str)
+        parser.add_argument(f"--checkpoint_path_{name}", type=str)
+        parser.add_argument(f"--start_from_checkpoint_{name}", type=str2bool)
+        parser.add_argument(f"--checkpoint_model_name_{name}", type=str)
+        parser.add_argument(f"--latent_size_{name}", type=int)
+
+        return parser
diff --git a/src/gt4sd/frameworks/granular/ml/models/utils.py b/src/gt4sd/frameworks/granular/ml/models/utils.py
new file mode 100644
index 000000000..386b3fda6
--- /dev/null
+++ b/src/gt4sd/frameworks/granular/ml/models/utils.py
@@ -0,0 +1,37 @@
+"""Model utilities."""
+
+
+class KLAnnealer:
+    """Annealer scaling KL weights (beta) linearly according to the number of epochs."""
+
+    def __init__(
+        self, kl_low: float, kl_high: float, n_epochs: int, start_epoch: int
+    ) -> None:
+        """Construct KLAnnealer.
+
+        Args:
+            kl_low: low KL weight.
+            kl_high: high KL weight.
+            n_epochs: number of epochs.
+            start_epoch: starting epoch.
+        """
+        self.kl_low = kl_low
+        self.kl_high = kl_high
+        self.n_epochs = n_epochs
+        self.start_epoch = start_epoch
+        self.kl = (self.kl_high - self.kl_low) / (self.n_epochs - self.start_epoch)
+
+    def __call__(self, epoch: int) -> float:
+        """Call the annealer.
+
+        Args:
+            epoch: current epoch number.
+
+        Returns:
+            the beta weight.
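The ramp implemented in __call__ just below is linear and clamped; the schedule it produces, reproduced standalone with hypothetical settings:

```python
# beta stays at kl_low until start_epoch, then climbs linearly and is
# clamped at kl_high (the values below are illustrative only).
kl_low, kl_high, n_epochs, start_epoch = 0.0, 0.1, 10, 2
slope = (kl_high - kl_low) / (n_epochs - start_epoch)
for epoch in (0, 2, 6, 12):
    k = (epoch - start_epoch) if epoch >= start_epoch else 0
    beta = min(kl_low + k * slope, kl_high)
    print(epoch, round(beta, 4))  # 0 0.0 / 2 0.0 / 6 0.05 / 12 0.1
```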
+ """ + k = (epoch - self.start_epoch) if epoch >= self.start_epoch else 0 + beta = self.kl_low + k * self.kl + if beta > self.kl_high: + beta = self.kl_high + return beta diff --git a/src/gt4sd/frameworks/granular/ml/models/vae_mlp/__init__.py b/src/gt4sd/frameworks/granular/ml/models/vae_mlp/__init__.py new file mode 100644 index 000000000..766d26e5c --- /dev/null +++ b/src/gt4sd/frameworks/granular/ml/models/vae_mlp/__init__.py @@ -0,0 +1,3 @@ +"""Initialize MLP variational autoencoder module.""" + +from .core import VaeMlp # noqa: F401 diff --git a/src/gt4sd/frameworks/granular/ml/models/vae_mlp/core.py b/src/gt4sd/frameworks/granular/ml/models/vae_mlp/core.py new file mode 100644 index 000000000..89fe4b876 --- /dev/null +++ b/src/gt4sd/frameworks/granular/ml/models/vae_mlp/core.py @@ -0,0 +1,285 @@ +"""VaeMlp implementation.""" + +from argparse import ArgumentParser +from typing import Any, Dict, Tuple + +from torch import nn + +from ....arg_parser.utils import str2bool +from ..base_model import GranularEncoderDecoderModel +from ..loss import LOSS_FACTORY +from ..module import MlpDecoder, MlpEncoder +from ..utils import KLAnnealer + + +class VaeMlp(GranularEncoderDecoderModel): + """VaeMlp - variational encoder using MLP with Gaussian prior and approximate posterior.""" + + def __init__( + self, + name: str, + position: int, + data: Dict[str, str], + input_size_enc: int = 256, + hidden_size_enc: int = 256, + n_layers_enc: int = 2, + activation_enc: str = "linear", + dropout_enc: float = 0.0, + hidden_size_dec: int = 256, + n_layers_dec: int = 2, + activation_dec: str = "linear", + dropout_dec: float = 0.0, + output_size_dec: int = 256, + latent_size: int = 196, + loss_function: str = "mse", + kl_low: float = 0.0, + kl_high: float = 0.1, + kl_n_epochs: int = 100, + kl_start_epoch: int = 0, + **kwargs, + ) -> None: + """Construct VaeMlp. + + Args: + name: model name. + position: position of the model. + data: data name mappings. + input_size_enc: encoder input size. Defaults to 256. + hidden_size_enc: encoder hidden size. Defaults to 256. + n_layers_enc: number of layers for the encoder. Defaults to 2. + activation_enc: activation function for the encoder. Defaults to "linear". + dropout_enc: encoder dropout rate. Defaults to 0.0. + hidden_size_dec: decoder hidden size. Defaults to 256. + n_layers_dec: number of layers for the decoder. Defaults to 2. + activation_dec: activation function for the decoder. Defaults to "linear". + dropout_dec: decoder dropout rate. Defaults to 0.0. + output_size_dec: decoder output size. Defaults to 256. + latent_size: size of the latent space. Defaults to 196. + loss_function: loss function. Defaults to "mse". + kl_low: low KL weight. + kl_high: high KL weight. + kl_n_epochs: KL number of epochs. + kl_start_epoch: KL starting epoch. + + Raises: + ValueError: in case the provided loss function is not supported. 
+ """ + super().__init__(name=name, data=data) + self.position = position + self.input_key = f"{name}_{data['input']}" + self.target_key = f"{name}_{data['target']}" + + self.latent_size = latent_size + self.hidden_size_enc = hidden_size_enc + self.hidden_size_dec = hidden_size_dec + self.input_size_enc = input_size_enc + self.hidden_size_enc = hidden_size_enc + self.n_layers_enc = n_layers_enc + self.activation_enc = activation_enc + self.dropout_enc = dropout_enc + self.hidden_size_dec = hidden_size_dec + self.n_layers_dec = n_layers_dec + self.activation_dec = activation_dec + self.dropout_dec = dropout_dec + self.output_size_dec = output_size_dec + + self.loss_function_name = loss_function.lower() + if self.loss_function_name not in LOSS_FACTORY: + raise ValueError( + f"loss_function={self.loss_function_name} not supported. Pick a valid one: {sorted(list(LOSS_FACTORY.keys()))}" + ) + self.loss_function = LOSS_FACTORY[self.loss_function_name] + + self.encoder = MlpEncoder( + input_size=input_size_enc, + hidden_size=hidden_size_enc, + output_size=hidden_size_enc, + n_layers=n_layers_enc, + activation=activation_enc, + dropout=dropout_enc, + ) + self.fc_mu = nn.Linear(hidden_size_enc, latent_size) + self.fc_var = nn.Linear(hidden_size_enc, latent_size) + self.decoder = MlpDecoder( + latent_size=latent_size, + hidden_size=hidden_size_dec, + output_size=output_size_dec, + n_layers=n_layers_dec, + activation=activation_dec, + dropout=dropout_dec, + ) + + self.epoch_counter = 0 + + self.kl_annealer = KLAnnealer( + kl_low=kl_low, + kl_high=kl_high, + n_epochs=kl_n_epochs, + start_epoch=kl_start_epoch, + ) + + def decode(self, z: Any, *args, **kwargs) -> Any: + """Decode a latent space point. + + Args: + z: latent point. + + Returns: + decoded sample. + """ + return self.decoder(z) + + def _sampling_step(self, x: Any, *args, **kwargs) -> Any: + """Run a sampling step in the model. + + Args: + x: model input. + + Returns: + model sampling step output. + """ + x = self.encoder(x) + mu = self.fc_mu(x) + log_var = self.fc_var(x) + return self.sample(mu, log_var) + + def encode(self, x: Any, *args, **kwargs) -> Any: + """Encode a sample. + + Args: + x: input sample. + + Returns: + latent encoding. + """ + _, _, z = self._sampling_step(x) + return z + + def _run_step(self, x: Any, *args, **kwargs) -> Any: + """Run a step in the model. + + Args: + x: model input. + + Returns: + model step output. + """ + p, q, z = self._sampling_step(x) + return z, self.decoder(z), p, q + + def step( + self, + input_data: Any, + target_data: Any, + device: str = "cpu", + current_epoch: int = 0, + *args, + **kwargs, + ) -> Tuple[Any, Any, Any]: + """Training step for the model. + + Args: + input_data: input for the step. + target_data: target for the step. + device: string representing the device to use. Defaults to "cpu". + current_epoch: current epoch. Defaults to 0. + + Returns: + a tuple containing the step output, the loss and the logs for the module. 
+ """ + x = input_data + + z, x_hat, p, q = self._run_step(x) + + x_hat = x_hat.view(-1, x_hat.size(-1)) + + reconstruction_loss = self.loss_function(x_hat, x) + + log_qz = q.log_prob(z) + log_pz = p.log_prob(z) + kl_scaling_factor = self.kl_annealer(current_epoch) + kl = log_qz - log_pz + kl = kl.mean() + kl_scaled = kl * kl_scaling_factor + + loss = kl_scaled + reconstruction_loss + logs = { + "reconstruction_loss": reconstruction_loss, + "kl_scaled": kl_scaled, + "kl_unscaled": kl, + "kl_scaling_factor": kl_scaling_factor, + "loss": loss, + } + + return z, loss, logs + + def val_step( + self, + input_data: Any, + target_data: Any, + device: str = "cpu", + current_epoch: int = 0, + *args, + **kwargs, + ) -> Any: + """Validation step for the model. + + Args: + input_data: input for the step. + target_data: target for the step. + device: string representing the device to use. Defaults to "cpu". + current_epoch: current epoch. Defaults to 0. + + Returns: + a tuple containing the step output, the loss and the logs for the module. + """ + return self.step( + input_data=input_data, + target_data=target_data, + device=device, + current_epoch=current_epoch, + ) + + @staticmethod + def add_model_specific_args( + parent_parser: ArgumentParser, name: str, *args, **kwargs + ) -> ArgumentParser: + """Adding to a parser model specific arguments. + + Args: + parent_parser: patent parser. + name: model name. + + Returns: + updated parser. + """ + parser = ArgumentParser(parents=[parent_parser], add_help=False) + + parser.add_argument(f"--data_path_{name}", type=str) + parser.add_argument(f"--data_file_{name}", type=str) + parser.add_argument(f"--dataset_type_{name}", type=str) + parser.add_argument(f"--position_{name}", type=int, nargs="+") + parser.add_argument(f"--input_{name}", type=str) + parser.add_argument(f"--target_{name}", type=str) + parser.add_argument(f"--checkpoint_path_{name}", type=str) + parser.add_argument(f"--start_from_checkpoint_{name}", type=str2bool) + parser.add_argument(f"--freeze_weights_{name}", type=str2bool) + parser.add_argument(f"--checkpoint_model_name_{name}", type=str) + parser.add_argument(f"--input_size_enc_{name}", type=int) + parser.add_argument(f"--hidden_size_enc_{name}", type=int) + parser.add_argument(f"--n_layers_enc_{name}", type=int) + parser.add_argument(f"--activation_enc_{name}", type=str) + parser.add_argument(f"--dropout_enc_{name}", type=float) + parser.add_argument(f"--hidden_size_dec_{name}", type=int) + parser.add_argument(f"--dropout_dec_{name}", type=float) + parser.add_argument(f"--n_layers_dec_{name}", type=int) + parser.add_argument(f"--activation_dec_{name}", type=str) + parser.add_argument(f"--ouptput_size_enc_{name}", type=int) + parser.add_argument(f"--loss_function_{name}", type=str) + parser.add_argument(f"--latent_size_{name}", type=int) + parser.add_argument(f"--kl_low_{name}", type=float) + parser.add_argument(f"--kl_high_{name}", type=float) + parser.add_argument(f"--kl_n_epochs_{name}", type=int) + parser.add_argument(f"--kl_start_epoch_{name}", type=int) + + return parser diff --git a/src/gt4sd/frameworks/granular/ml/models/vae_rnn/__init__.py b/src/gt4sd/frameworks/granular/ml/models/vae_rnn/__init__.py new file mode 100644 index 000000000..18d17c8e1 --- /dev/null +++ b/src/gt4sd/frameworks/granular/ml/models/vae_rnn/__init__.py @@ -0,0 +1,3 @@ +"""Initialize RNN variational autoencoder module.""" + +from .core import VaeRnn # noqa: F401 diff --git a/src/gt4sd/frameworks/granular/ml/models/vae_rnn/core.py 
new file mode 100644
index 000000000..02b0f67a6
--- /dev/null
+++ b/src/gt4sd/frameworks/granular/ml/models/vae_rnn/core.py
@@ -0,0 +1,339 @@
+"""VaeRnn implementation."""
+
+from argparse import ArgumentParser
+from typing import Any, Dict, Tuple
+
+import torch
+from torch import nn
+
+from ....arg_parser.utils import str2bool
+from ....tokenizer import Tokenizer
+from ..base_model import GranularEncoderDecoderModel
+from ..loss import LOSS_FACTORY
+from ..module import RnnDecoder, RnnEncoder
+from ..utils import KLAnnealer
+
+
+class VaeRnn(GranularEncoderDecoderModel):
+    """VaeRnn - variational autoencoder using RNN with Gaussian prior and approximate posterior."""
+
+    def __init__(
+        self,
+        name: str,
+        position: int,
+        data: Dict[str, str],
+        vocab_size: int,
+        embedding_size: int,
+        tokenizer: Tokenizer,
+        hidden_size_enc: int = 256,
+        n_layers_enc: int = 2,
+        hidden_size_dec: int = 256,
+        n_layers_dec: int = 2,
+        bidirectional: bool = False,
+        latent_size: int = 196,
+        teacher_forcing: bool = True,
+        loss_function: str = "ce",
+        kl_low: float = 0.0,
+        kl_high: float = 0.1,
+        kl_n_epochs: int = 100,
+        kl_start_epoch: int = 0,
+        inference_check_frequency: int = 50,
+        **kwargs,
+    ) -> None:
+        """Construct VaeRnn.
+
+        Args:
+            name: model name.
+            position: position of the model.
+            data: data name mappings.
+            vocab_size: size of the vocabulary.
+            embedding_size: size of the embedding vectors.
+            tokenizer: tokenizer.
+            hidden_size_enc: encoder hidden size. Defaults to 256.
+            n_layers_enc: number of layers for the encoder. Defaults to 2.
+            hidden_size_dec: decoder hidden size. Defaults to 256.
+            n_layers_dec: number of layers for the decoder. Defaults to 2.
+            bidirectional: whether the RNN cell is bidirectional. Defaults to False.
+            latent_size: latent size. Defaults to 196.
+            teacher_forcing: whether to use teacher forcing. Defaults to True.
+            loss_function: loss function. Defaults to "ce".
+            kl_low: low KL weight. Defaults to 0.0.
+            kl_high: high KL weight. Defaults to 0.1.
+            kl_n_epochs: KL number of epochs. Defaults to 100.
+            kl_start_epoch: KL starting epoch. Defaults to 0.
+            inference_check_frequency: frequency for checking inference quality. Defaults to 50.
+
+        Raises:
+            ValueError: in case the provided loss function is not supported.
+        """
+        super().__init__(name=name, data=data)
+        self.position = position
+        self.input_key = f"{name}_{data['input']}"
+        self.target_key = f"{name}_{data['target']}"
+
+        self.latent_size = latent_size
+        self.vocab_size = vocab_size
+        self.embedding_size = embedding_size
+        self.teacher_forcing = teacher_forcing
+        self.tokenizer = tokenizer
+        self.hidden_size_enc = hidden_size_enc
+        self.n_layers_enc = n_layers_enc
+        self.hidden_size_dec = hidden_size_dec
+        self.n_layers_dec = n_layers_dec
+        self.hidden_factor = (2 if bidirectional else 1) * n_layers_enc
+
+        self.loss_function_name = loss_function.lower()
+        if self.loss_function_name not in LOSS_FACTORY:
+            raise ValueError(
+                f"loss_function={self.loss_function_name} not supported.
Pick a valid one: {sorted(list(LOSS_FACTORY.keys()))}" + ) + self.loss_function = LOSS_FACTORY[self.loss_function_name] + + self.fc_mu = nn.Linear(self.hidden_factor * hidden_size_enc, self.latent_size) + self.fc_var = nn.Linear(self.hidden_factor * hidden_size_enc, self.latent_size) + self.encoder = RnnEncoder( + vocab_size=vocab_size, + embedding_size=embedding_size, + hidden_size=hidden_size_enc, + n_layers=n_layers_enc, + bidirectional=bidirectional, + ) + self.decoder = RnnDecoder( + vocab_size=vocab_size, + embedding_size=embedding_size, + hidden_size=hidden_size_dec, + n_layers=n_layers_dec, + latent_size=latent_size, + ) + + self.epoch_counter = 0 + self.klannealer = KLAnnealer( + kl_low=kl_low, + kl_high=kl_high, + n_epochs=kl_n_epochs, + start_epoch=kl_start_epoch, + ) + self.inference_check_frequency = inference_check_frequency + + def decode(self, z: Any, max_len: int = 127, *args, **kwargs) -> Any: + """Decode a latent space point. + + Args: + z: latent point. + max_len: maximum sequence length. Defaults to 127. + + Returns: + tuple with decoded texts and token indices. + """ + decoded_texts, token_indices = self.decoder.inference_direct( + z, self.encoder.embedding, self.tokenizer, max_len=max_len + ) + return decoded_texts, token_indices + + def _sampling_step(self, x: Any, *args, **kwargs) -> Any: + """Run a sampling step in the model. + + Args: + x: model input. + + Returns: + model sampling step output. + """ + x, input_embedding = self.encoder(x) + mu = self.fc_mu(x) + log_var = self.fc_var(x) + p, q, z = self.sample(mu, log_var) + return p, q, z, input_embedding + + def encode(self, x: Any, *args, **kwargs) -> Any: + """Encode a sample. + + Args: + x: input sample. + + Returns: + latent encoding. + """ + _, _, z, _ = self._sampling_step(x) + return z + + def encode_decode(self, x: Any, max_len: int = 127, *args, **kwargs) -> Any: + """Encode and decode a sample. + + Args: + x: input sample. + max_len: maximum sequence length. Defaults to 127. + + Returns: + decoded sample. + """ + z = self.encode(x) + return self.decode(z, max_len=max_len) + + def inference(self, x: Any, *args, **kwargs) -> Any: # type:ignore + """Run the model in inference mode. + + Args: + x: sample. + + Returns: + generated output. + """ + max_len = x.size(1) + _, _, z, _ = self._sampling_step(x) + return self.decode(z, max_len=max_len) + + def _run_step(self, x: Any, *args, **kwargs) -> Any: + """Run a step in the model. + + Args: + x: model input. + + Returns: + model step output. + """ + p, q, z, input_embedding = self._sampling_step(x) + return z, self.decoder(z, input_embedding), p, q + + def step( + self, + input_data: Any, + target_data: Any, + device: str = "cpu", + current_epoch: int = 0, + *args, + **kwargs, + ) -> Tuple[Any, Any, Any]: + """Training step for the model. + + Args: + input_data: input for the step. + target_data: target for the step. + device: string representing the device to use. Defaults to "cpu". + current_epoch: current epoch. Defaults to 0. + + Returns: + a tuple containing the step output, the loss and the logs for the module. 
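+
+        Example (sketch of the teacher-forcing shift applied below; ``tokens``
+        is a hypothetical batch of sos/eos-delimited token indices)::
+
+            # with teacher_forcing=True the decoder input drops the last
+            # token and the target drops the first one:
+            #   decoder input: tokens[:, :-1]
+            #   target:        tokens[:, 1:]
+            z, loss, logs = model.step(tokens, tokens, current_epoch=3)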
+ """ + x = input_data + x_out = target_data + # teacher forcing + if self.teacher_forcing: + x_out = x_out[:, 1:].long() + x = x[:, :-1] + + z, x_hat, p, q = self._run_step(x) + + x_hat = x_hat.view(-1, x_hat.size(-1)) + x_target = x_out.contiguous().view(-1) + + reconstruction_loss = self.loss_function(x_hat, x_target) + + log_qz = q.log_prob(z) + log_pz = p.log_prob(z) + kl_scaling_factor = self.klannealer(current_epoch) + kl = log_qz - log_pz + kl = kl.mean() + kl_scaled = kl * kl_scaling_factor + + loss = kl_scaled + reconstruction_loss + logs = { + "reconstruction_loss": reconstruction_loss, + "kl_scaled": kl_scaled, + "kl_unscaled": kl, + "kl_scaling_factor": kl_scaling_factor, + "loss": loss, + } + + return z, loss, logs + + def val_step( + self, + input_data: Any, + target_data: Any, + device: str = "cpu", + current_epoch: int = 0, + *args, + **kwargs, + ) -> Any: + """Validation step for the model. + + Args: + input_data: input for the step. + target_data: target for the step. + device: string representing the device to use. Defaults to "cpu". + current_epoch: current epoch. Defaults to 0. + + Returns: + a tuple containing the step output, the loss and the logs for the module. + """ + x = input_data + z, loss, logs = self.step( + input_data=input_data, + target_data=target_data, + device=device, + current_epoch=current_epoch, + ) + if current_epoch % self.inference_check_frequency == 0 and current_epoch > 0: + decoded_texts, token_indices = self.inference(x) + reconstructed_texts = 0 + decoded_splitted_texts = [ + text.split(self.tokenizer.eos_token, 1)[0] for text in decoded_texts + ] + for _, text in enumerate(decoded_splitted_texts): + if self.tokenizer.pad_token not in text: + reconstructed_texts += 1 + valid_percentage = float(reconstructed_texts) / x.size(0) + reconstructed_bits = torch.sum( + x[:, 1:] == token_indices[:, : x[:, 1:].size(1)] + ).item() + reconstructed_bits_percentage = reconstructed_bits / x.numel() + logs.update( + { + "reconstructed_bits": reconstructed_bits_percentage, + "validity": valid_percentage, + } + ) + + return z, loss, logs + + @staticmethod + def add_model_specific_args( + parent_parser: ArgumentParser, name: str, *args, **kwargs + ) -> ArgumentParser: + """Adding to a parser model specific arguments. + + Args: + parent_parser: patent parser. + name: model name. + + Returns: + updated parser. 
+ """ + parser = ArgumentParser(parents=[parent_parser], add_help=False) + + parser.add_argument(f"--data_path_{name}", type=str) + parser.add_argument(f"--data_file_{name}", type=str) + parser.add_argument(f"--dataset_type_{name}", type=str) + parser.add_argument(f"--position_{name}", type=int, nargs="+") + parser.add_argument(f"--build_vocab{name}", type=str2bool) + parser.add_argument(f"--vocab_file{name}", type=str) + parser.add_argument(f"--input_{name}", type=str) + parser.add_argument(f"--target_{name}", type=str) + parser.add_argument(f"--checkpoint_path_{name}", type=str) + parser.add_argument(f"--checkpoint_model_name_{name}", type=str) + parser.add_argument(f"--start_from_checkpoint_{name}", type=str2bool) + parser.add_argument(f"--freeze_weights_{name}", type=str2bool) + parser.add_argument(f"--hidden_size_enc_{name}", type=int) + parser.add_argument(f"--hidden_size_dec_{name}", type=int) + parser.add_argument(f"--n_layers_enc_{name}", type=int) + parser.add_argument(f"--n_layers_dec_{name}", type=int) + parser.add_argument(f"--bidirectional_{name}", type=str2bool) + parser.add_argument(f"--latent_size_{name}", type=int) + parser.add_argument(f"--kl_low_{name}", type=float) + parser.add_argument(f"--kl_high_{name}", type=float) + parser.add_argument(f"--kl_n_epochs_{name}", type=int) + parser.add_argument(f"--kl_start_epoch_{name}", type=int) + parser.add_argument(f"--inference_check_frequency_{name}", type=int) + + return parser diff --git a/src/gt4sd/frameworks/granular/ml/models/vae_trans/__init__.py b/src/gt4sd/frameworks/granular/ml/models/vae_trans/__init__.py new file mode 100644 index 000000000..bb492ea0f --- /dev/null +++ b/src/gt4sd/frameworks/granular/ml/models/vae_trans/__init__.py @@ -0,0 +1,3 @@ +"""Initialize Transformer-based variational autoencoder module.""" + +from .core import VaeTrans # noqa: F401 diff --git a/src/gt4sd/frameworks/granular/ml/models/vae_trans/core.py b/src/gt4sd/frameworks/granular/ml/models/vae_trans/core.py new file mode 100644 index 000000000..e22769e68 --- /dev/null +++ b/src/gt4sd/frameworks/granular/ml/models/vae_trans/core.py @@ -0,0 +1,405 @@ +"""VaeTrans implementation.""" + +from argparse import ArgumentParser +from typing import Any, Dict, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from .....torch import get_device_from_tensor +from ....arg_parser.utils import str2bool +from ..base_model import GranularEncoderDecoderModel +from ..loss import LOSS_FACTORY +from ..module import TransformerDecoder, TransformerEncoder +from ..utils import KLAnnealer + + +class VaeTrans(GranularEncoderDecoderModel): + """Transformer-based VAE with Gaussian Prior and approx posterior.""" + + def __init__( + self, + name: str, + position: int, + data: Dict[str, str], + vocab_size: int, + tokenizer, + hidden_size_enc: int = 256, + n_layers_enc: int = 2, + hidden_size_dec: int = 256, + n_layers_dec: int = 2, + kl_coeff: float = 0.1, + latent_size: int = 196, + feedforward_size: int = 512, + heads: int = 4, + dropout: float = 0.1, + bypass_bottleneck: bool = False, + seq_len: int = 127, + teacher_forcing: bool = True, + loss_function: str = "ce", + kl_low: float = 0.0, + kl_high: float = 0.1, + kl_n_epochs: int = 100, + kl_start_epoch: int = 0, + inference_check_frequency: int = 50, + **kwargs, + ): + """Construct VaeRnn. + + Args: + name: model name. + position: position of the model. + data: data name mappings. + vocab_size: size of the vocabulary. + tokenizer: tokenizer. 
+ hidden_size_enc: encoder hidden size. Defaults to 256. + n_layers_enc: number of layers for the encoder. Defaults to 2. + hidden_size_dec: decoder hidden size. Defaults to 256. + n_layers_dec: number of layers for the decoder. Defaults to 2. + kl_coeff: KL coefficient. Defaults to 0.1. + latent_size: latent size. Defaults to 196. + feedforward_size: size of the feed forward network. Default to 512. + heads: number of heads. Defauls to 4. + dropout: dropout rate. Defaults to 0.1. + bypass_bottleneck: whether the bottleneck should be by passed. + Defaults to False. + seq_len: length of the sequence. Defaults to 127. + teacher_forcing: whether to teacher forcing. Defaults to True. + loss_function: loss function. Defaults to "ce". + kl_low: low KL weight. Defaults to 0.0. + kl_high: high KL weight. Defaults to 0.1. + kl_n_epochs: KL number of epochs. Defaults to 100. + kl_start_epoch: KL starting epoch. Defaults to 0. + inference_check_frequency: frequency for checking inference quality. Defaults to 50. + + Raises: + ValueError: in case the provided loss function is not supported. + """ + super().__init__(name=name, data=data) + self.position = position + self.input_key = f"{name}_{data['input']}" + self.target_key = f"{name}_{data['target']}" + self.kl_coeff = kl_coeff + self.latent_size = latent_size + self.vocab_size = vocab_size + self.tokenizer = tokenizer + self.teacher_forcing = teacher_forcing + self.seq_len = seq_len + + self.loss_function_name = loss_function.lower() + if self.loss_function_name not in LOSS_FACTORY: + raise ValueError( + f"loss_function={self.loss_function_name} not supported. Pick a valid one: {sorted(list(LOSS_FACTORY.keys()))}" + ) + self.loss_function = LOSS_FACTORY[self.loss_function_name] + + self.predict_len1 = nn.Linear(self.latent_size, self.latent_size * 2) + self.predict_len2 = nn.Linear(self.latent_size * 2, self.seq_len) + + self.encoder = TransformerEncoder( + hidden_size_enc, + feedforward_size, + seq_len, + dropout, + heads, + n_layers_enc, + vocab_size, + bypass_bottleneck, + ) + self.decoder = TransformerDecoder( + hidden_size_dec, + feedforward_size, + seq_len, + dropout, + heads, + n_layers_dec, + latent_size, + vocab_size, + bypass_bottleneck, + self.encoder.conv_output_shape, + ) + self.fc_mu = nn.Linear(self.encoder.conv_output_len, self.latent_size) + self.fc_var = nn.Linear(self.encoder.conv_output_len, self.latent_size) + + self.klannealer = KLAnnealer( + kl_low=kl_low, + kl_high=kl_high, + n_epochs=kl_n_epochs, + start_epoch=kl_start_epoch, + ) + self.inference_check_frequency = inference_check_frequency + + def forward(self, x: Any, tgt: torch.Tensor, *args, **kwrgs) -> Any: # type:ignore + """Forward pass in the model. + + Args: + x: model input. + tgt: target tensor + + Returns: + model output. + """ + x_out, _, _, z, _, _ = self._run_step(x, tgt) + return x_out, z + + def predict_mask_length(self, mem: torch.Tensor) -> Any: + """Predicts mask length from latent memory so mask can be re-created during inference. + + Args: + mem: latent memory. + + Returns: + mask length. + """ + pred_len = self.predict_len1(mem) + pred_len = self.predict_len2(pred_len) + pred_len = F.softmax(pred_len, dim=-1) + pred_len = torch.topk(pred_len, 1)[1] + return pred_len + + def _sampling_step(self, x: Any, *args, **kwargs) -> Any: + """Run a sampling step in the model. + + Args: + x: model input. + + Returns: + model sampling step output. 
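+
+        Example (sketch of the returned triple, assuming the base-class
+        ``sample`` follows the usage elsewhere in this module)::
+
+            p, q, z = model._sampling_step(x)
+            # p: standard Gaussian prior distribution
+            # q: approximate posterior built from fc_mu/fc_var
+            # z: reparametrized latent sample of shape (batch_size, latent_size)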
+ """ + src_mask = (x != self.tokenizer.pad_token_id).unsqueeze(-2) + x = self.encoder(x, src_mask) + + mu = self.fc_mu(x) + log_var = self.fc_var(x) + return self.sample(mu, log_var) + + def encode(self, x: Any, *args, **kwargs) -> Any: + """Encode a sample. + + Args: + x: input sample. + + Returns: + latent encoding. + """ + _, _, z = self._sampling_step(x) + return z + + def decode(self, z: Any, *args, **kwargs) -> Any: + """Decode a latent space point. + + Args: + z: latent point. + + Returns: + tuple with decoded texts and token indices. + """ + mask_lens = self.predict_mask_length(z) + decoded_texts, token_indices = self.decoder.inference_direct( + z, mask_lens, self.tokenizer + ) + return decoded_texts, token_indices + + def encode_decode(self, x: Any, *args, **kwargs) -> Any: + """Encode and decode a sample. + + Args: + x: input sample. + + Returns: + decoded sample. + """ + z = self.encode(x) + _, token_indices = self.decode(z, x.device) + return token_indices + + def inference( # type:ignore + self, x: Any, *args, **kwargs + ) -> Any: + """Run the model in inference mode. + + Args: + x: sample. + + Returns: + generated output. + """ + device = get_device_from_tensor(x) + z = self.encode(x) + decoded_texts, token_indices = self.decode(z, device) + return decoded_texts, token_indices + + def _run_step(self, x: Any, tgt: torch.Tensor) -> Any: # type:ignore + """Run a step in the model. + + Args: + x: model input. + tgt: target tensor + + Returns: + model step output. + """ + src_mask = (x != self.tokenizer.pad_token_id).unsqueeze(-2) + tgt_mask = (tgt != self.tokenizer.pad_token_id).unsqueeze(-2) + attn_shape = (1, tgt.size(-1), tgt.size(-1)) + subsequent_mask = ( + torch.from_numpy(np.triu(np.ones(attn_shape), k=1).astype("uint8")) == 0 + ) + tgt_mask = tgt_mask & torch.autograd.Variable( + subsequent_mask.type_as(tgt_mask.data) + ) + x = self.encoder(x, src_mask) + mu = self.fc_mu(x) + log_var = self.fc_var(x) + mask_lens = self.predict_len1(mu) + mask_lens = self.predict_len2(mask_lens) + true_len = src_mask.sum(dim=-1).contiguous().view(-1) + p, q, z = self.sample(mu, log_var) + x_out = self.decoder(tgt, z, src_mask, tgt_mask) + return x_out, p, q, z, mask_lens, true_len + + def step( + self, + input_data: Any, + target_data: Any, + device: str = "cpu", + current_epoch: int = 0, + *args, + **kwargs, + ) -> Tuple[Any, Any, Any]: + """Training step for the model. + + Args: + input_data: input for the step. + target_data: target for the step. + device: string representing the device to use. Defaults to "cpu". + current_epoch: current epoch. Defaults to 0. + + Returns: + a tuple containing the step output, the loss and the logs for the module. 
+ """ + x = input_data + x_target = target_data + + if self.teacher_forcing: + x_tgt_in = x_target[:, :-1] + x_tgts_out = x_target.long()[:, 1:] + else: + x_tgt_in = x_target + x_tgts_out = x_target.long() + + x_pred_out, p, q, z, pred_len, true_len = self._run_step(x, x_tgt_in) + + len_loss = self.loss_function(pred_len, true_len) + + x_pred_out = x_pred_out.contiguous().view(-1, x_pred_out.size(2)) + x_tgts_out = x_tgts_out.contiguous().view(-1) + reconstruction_loss = self.loss_function(x_pred_out, x_tgts_out) + + log_qz = q.log_prob(z) + log_pz = p.log_prob(z) + kl_scaling_factor = self.klannealer(current_epoch) + kl = log_qz - log_pz + kl = kl.mean() + kl_scaled = kl * kl_scaling_factor + + loss = kl_scaled + reconstruction_loss + len_loss + logs = { + "reconstruction_loss": reconstruction_loss, + "kl_scaled": kl_scaled, + "kl_unscaled": kl, + "kl_scaling_factor": kl_scaling_factor, + "len_loss": len_loss, + "loss": loss, + } + return z, loss, logs + + def val_step( + self, + input_data: Any, + target_data: Any, + device: str = "cpu", + current_epoch: int = 0, + *args, + **kwargs, + ) -> Any: + """Validation step for the model. + + Args: + input_data: input for the step. + target_data: target for the step. + device: string representing the device to use. Defaults to "cpu". + current_epoch: current epoch. Defaults to 0. + + Returns: + a tuple containing the step output, the loss and the logs for the module. + """ + x = input_data + z, loss, logs = self.step(input_data, target_data, device, current_epoch) + + if current_epoch % self.inference_check_frequency == 0 and current_epoch > 0: + decoded_texts, token_indices = self.inference(x) + reconstructed_texts = 0 + decoded_splitted_texts = [ + text.split(self.tokenizer.eos_token, 1)[0] for text in decoded_texts + ] + for _, text in enumerate(decoded_splitted_texts): + if self.tokenizer.pad_token not in text: + reconstructed_texts += 1 + valid_percentage = float(reconstructed_texts) / x.size(0) + reconstructed_bits = torch.sum(x == token_indices).item() + reconstructed_bits_percentage = reconstructed_bits / x.numel() + logs.update( + { + "reconstructed_bits": reconstructed_bits_percentage, + "validity": valid_percentage, + } + ) + return z, loss, logs + + @staticmethod + def add_model_specific_args( + parent_parser: ArgumentParser, name: str, *args, **kwargs + ) -> ArgumentParser: + """Adding to a parser model specific arguments. + + Args: + parent_parser: patent parser. + name: model name. + + Returns: + updated parser. 
+ """ + parser = ArgumentParser(parents=[parent_parser], add_help=False) + + parser.add_argument(f"--data_path_{name}", type=str) + parser.add_argument(f"--data_file_{name}", type=str) + parser.add_argument(f"--dataset_type_{name}", type=str) + parser.add_argument(f"--position_{name}", type=int, nargs="+") + parser.add_argument(f"--build_vocab{name}", type=str2bool) + parser.add_argument(f"--vocab_file{name}", type=str) + parser.add_argument(f"--input_{name}", type=str) + parser.add_argument(f"--target_{name}", type=str) + parser.add_argument(f"--checkpoint_path_{name}", type=str) + parser.add_argument(f"--checkpoint_model_name_{name}", type=str) + parser.add_argument(f"--start_from_checkpoint_{name}", type=str2bool) + parser.add_argument(f"--freeze_weights_{name}", type=str2bool) + parser.add_argument(f"--hidden_size_enc_{name}", type=int) + parser.add_argument(f"--hidden_size_dec_{name}", type=int) + parser.add_argument(f"--n_layers_enc_{name}", type=int) + parser.add_argument(f"--n_layers_dec_{name}", type=int) + parser.add_argument(f"--bidirectional_{name}", type=str2bool) + parser.add_argument(f"--latent_size_{name}", type=int) + parser.add_argument("--feedforward_size", type=int) + parser.add_argument("--heads", type=int) + parser.add_argument("--dropout", type=float) + parser.add_argument("--bypass_bottleneck", type=str2bool) + parser.add_argument(f"--kl_low_{name}", type=float) + parser.add_argument(f"--kl_high_{name}", type=float) + parser.add_argument(f"--kl_n_epochs_{name}", type=int) + parser.add_argument(f"--kl_start_epoch_{name}", type=int) + parser.add_argument(f"--inference_check_frequency_{name}", type=int) + + return parser diff --git a/src/gt4sd/frameworks/granular/ml/module.py b/src/gt4sd/frameworks/granular/ml/module.py new file mode 100644 index 000000000..c0d2c1793 --- /dev/null +++ b/src/gt4sd/frameworks/granular/ml/module.py @@ -0,0 +1,262 @@ +"""Model combiner module.""" + +import os +from typing import Any, Callable, Dict, List, Tuple, cast + +import pandas as pd +import pytorch_lightning as pl +import torch + +from .models import GranularBaseModel, GranularEncoderDecoderModel +from .models.model_builder import building_models, define_latent_models_input_size + + +class GranularModule(pl.LightningModule): + """Module from granular.""" + + def __init__( + self, + architecture_autoencoders: List[Dict[str, Any]], + architecture_latent_models: List[Dict[str, Any]], + lr: float = 1e-4, + test_output_path: str = "./test", + **kwargs, + ) -> None: + """Construct GranularModule. + + Args: + architecture_autoencoders: list of autoencoder architecture configurations. + architecture_latent_models: list of latent model architecture configurations. + lr: learning rate for Adam optimizer. Defaults to 1e-4. + test_output_path: path where to save latent encodings and predictions for the test set + when an epoch ends. Defaults to a a folder called "test" in the current working directory. 
+ """ + super().__init__() + self.save_hyperparameters() + + architecture_latent_models = define_latent_models_input_size( + architecture_autoencoders, architecture_latent_models + ) + self.architecture_autoencoders = architecture_autoencoders + self.architecture_latent_models = architecture_latent_models + + self.autoencoders = building_models(self.architecture_autoencoders) + self.latent_models = building_models(self.architecture_latent_models) + + self.lr = lr + self.test_output_path = test_output_path + for model in self.autoencoders + self.latent_models: + setattr(self, model.name, model) + + def _autoencoder_step( + self, batch: Any, model: GranularEncoderDecoderModel, model_step_fn: Callable + ) -> Tuple[Any, Any, Any]: + """Autoencoder module forward pass. + + Args: + batch: batch representation. + model: a module. + model_step_fn: callable for the step. + + Returns: + a tuple containing the latent representation, the loss and the logs for the module. + """ + return model_step_fn( + input_data=batch[model.input_key], + target_data=batch[model.target_key], + device=self.device, + current_epoch=self.current_epoch, + ) + + def _latent_step( + self, + batch: Any, + model: GranularBaseModel, + model_step_fn: Callable, + z: Dict[int, Any], + ) -> Tuple[Any, Any, Any]: + """Latent module forward pass. + + Args: + batch: batch representation. + model: a module. + model_step_fn: callable for the step. + z: latent encodings. + + Returns: + a tuple containing the latent step ouput, the loss and the logs for the module. + """ + z_model_input = torch.cat( + [ + torch.squeeze(z[pos]) if len(z[pos].size()) == 3 else z[pos] + for pos in model.from_position + ], + dim=1, + ) + return model_step_fn( + input_data=z_model_input, + target_data=batch[model.target_key], + device=self.device, + current_epoch=self.current_epoch, + ) + + def training_step( # type:ignore + self, batch: Any, *args, **kwargs + ) -> Dict[str, Any]: + """Training step implementation. + + Args: + batch: batch representation. + + Returns: + loss and logs. + """ + loss = 0.0 + z = dict() + logs = dict() + + for model in self.autoencoders: + z[model.position], loss_model, logs_model = self._autoencoder_step( + batch=batch, + model=cast(GranularEncoderDecoderModel, model), + model_step_fn=model.step, + ) + logs.update({model.name + f"/{k}": v for k, v in logs_model.items()}) + loss += loss_model + + for model in self.latent_models: + _, loss_model, logs_model = self._latent_step( + batch=batch, model=model, model_step_fn=model.step, z=z + ) + logs.update({model.name + f"/{k}": v for k, v in logs_model.items()}) + loss += loss_model + + logs.update({"total_loss": loss}) + self.log_dict( + {f"train/{k}": v for k, v in logs.items()}, on_epoch=False, prog_bar=False + ) + logs_epoch = {f"train_epoch/{k}": v for k, v in logs.items()} + logs_epoch["step"] = self.current_epoch + self.log_dict(logs_epoch, on_step=False, on_epoch=True, prog_bar=False) + + return {"loss": loss, "logs": logs} + + def validation_step( # type:ignore + self, batch: Any, *args, **kwargs + ) -> Dict[str, Any]: + """Validation step implementation. + + Args: + batch: batch representation. + + Returns: + loss and logs. 
+ """ + loss = 0.0 + z = dict() + logs = dict() + + for model in self.autoencoders: + z[model.position], loss_model, logs_model = self._autoencoder_step( + batch=batch, + model=cast(GranularEncoderDecoderModel, model), + model_step_fn=model.val_step, + ) + logs.update({model.name + f"/{k}": v for k, v in logs_model.items()}) + loss += loss_model + + for model in self.latent_models: + _, loss_model, logs_model = self._latent_step( + batch=batch, model=model, model_step_fn=model.val_step, z=z + ) + logs.update({model.name + f"/{k}": v for k, v in logs_model.items()}) + loss += loss_model + + logs.update({"total_loss": loss}) + self.log_dict( + {f"val/{k}": v for k, v in logs.items()}, on_epoch=True, prog_bar=True + ) + + return {"loss": loss, "logs": logs} + + def test_step( # type:ignore + self, batch: Any, batch_idx: int, *args, **kwargs + ) -> Dict[str, Any]: + """Testing step implementation. + + Args: + batch: batch representation. + batch_idx: batch index, unused. + + Returns: + loss, logs, and latent encodings. + """ + loss = 0.0 + z = dict() + logs = dict() + + for model in self.autoencoders: + z[model.position], loss_model, logs_model = self._autoencoder_step( + batch=batch, + model=cast(GranularEncoderDecoderModel, model), + model_step_fn=model.val_step, + ) + logs.update({model.name + f"/{k}": v for k, v in logs_model.items()}) + loss += loss_model + + for model in self.latent_models: + _, loss_model, logs_model = self._latent_step( + batch=batch, model=model, model_step_fn=model.val_step, z=z + ) + logs.update({model.name + f"/{k}": v for k, v in logs_model.items()}) + loss += loss_model + + logs.update({"total_loss": loss}) + self.log_dict( + {f"val/{k}": v for k, v in logs.items()}, on_epoch=True, prog_bar=True + ) + return {"loss": loss, "logs": logs, "z": z} + + def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: # type:ignore + """Callback called at the end of an epoch on test outputs. + + Dump encodings and targets for the test set. + + Args: + outputs: outputs for test batches. + """ + z = {} + targets = {} + z_keys = [key for key in outputs[0]["z"]] + targets_keys = [key for key in outputs[0]["targets"]] + for key in z_keys: + z[key] = ( + torch.cat( + [torch.squeeze(an_output["z"][key]) for an_output in outputs], dim=0 + ) + .detach() + .cpu() + .numpy() + ) + + for key in targets_keys: + targets[key] = ( + torch.cat( + [torch.squeeze(an_output["targets"][key]) for an_output in outputs], + dim=0, + ) + .detach() + .cpu() + .numpy() + ) + + pd.to_pickle(z, f"{self.test_output_path}{os.path.sep}z_build.pkl") + pd.to_pickle(targets, f"{self.test_output_path}{os.path.sep}targets.pkl") + + def configure_optimizers(self) -> torch.optim.Optimizer: + """Configure optimizers. + + Returns: + an optimizer, currently only Adam is supported. 
+ """ + return torch.optim.Adam(self.parameters(), lr=self.lr) diff --git a/src/gt4sd/frameworks/granular/tests/__init__.py b/src/gt4sd/frameworks/granular/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/gt4sd/frameworks/granular/tests/test_tokenizer.py b/src/gt4sd/frameworks/granular/tests/test_tokenizer.py new file mode 100644 index 000000000..0399bd472 --- /dev/null +++ b/src/gt4sd/frameworks/granular/tests/test_tokenizer.py @@ -0,0 +1,161 @@ +"""Tests for granular tokenizer.""" + +from gt4sd.frameworks.granular.tokenizer.tokenizer import ( + SelfiesTokenizer, + SmilesTokenizer, +) + + +def test_tokenization(): + smiles = [ + "c1ccccc1", + "c1ccc(CP(c2ccccc2)c2ccccc2)cc1.CCCCN1[C]N(Cc2ccccc2)c2ccccc21.[Ag]", + ] + + def _test_tokenizer(tokenizer_type, tokens_groundtruth): + tokenizer = tokenizer_type("test", smiles=smiles) + tokens = tokenizer.tokenize(smiles[1]) + assert tokens_groundtruth == tokens + assert [ + tokenizer.vocab[token] for token in tokens + ] == tokenizer.convert_tokens_to_ids(tokenizer.tokenize(smiles[1])) + assert 2 == len( + [ + tokenizer.convert_tokens_to_ids(tokenizer.tokenize(a_smiles)) + for a_smiles in smiles + ] + ) + + _test_tokenizer( + SelfiesTokenizer, + [ + "[c]", + "[c]", + "[c]", + "[c]", + "[Branch1_3]", + "[=S]", + "[C]", + "[P]", + "[Branch1_3]", + "[Branch2_2]", + "[c]", + "[c]", + "[c]", + "[c]", + "[c]", + "[c]", + "[Ring1]", + "[Branch1_1]", + "[c]", + "[c]", + "[c]", + "[c]", + "[c]", + "[c]", + "[Ring1]", + "[Branch1_1]", + "[c]", + "[c]", + "[Ring1]", + "[#C]", + "[.]", + "[C]", + "[C]", + "[C]", + "[C]", + "[N]", + "[Cexpl]", + "[N]", + "[Branch1_3]", + "[Branch2_3]", + "[C]", + "[c]", + "[c]", + "[c]", + "[c]", + "[c]", + "[c]", + "[Ring1]", + "[Branch1_1]", + "[c]", + "[c]", + "[c]", + "[c]", + "[c]", + "[c]", + "[Ring1]", + "[Branch1_1]", + "[Ring1]", + "[=N]", + "[.]", + "[Agexpl]", + ], + ) + + _test_tokenizer( + SmilesTokenizer, + [ + "c", + "1", + "c", + "c", + "c", + "(", + "C", + "P", + "(", + "c", + "2", + "c", + "c", + "c", + "c", + "c", + "2", + ")", + "c", + "2", + "c", + "c", + "c", + "c", + "c", + "2", + ")", + "c", + "c", + "1", + ".", + "C", + "C", + "C", + "C", + "N", + "1", + "[C]", + "N", + "(", + "C", + "c", + "2", + "c", + "c", + "c", + "c", + "c", + "2", + ")", + "c", + "2", + "c", + "c", + "c", + "c", + "c", + "2", + "1", + ".", + "[Ag]", + ], + ) diff --git a/src/gt4sd/frameworks/granular/tokenizer/__init__.py b/src/gt4sd/frameworks/granular/tokenizer/__init__.py new file mode 100644 index 000000000..c545ed9d2 --- /dev/null +++ b/src/gt4sd/frameworks/granular/tokenizer/__init__.py @@ -0,0 +1,9 @@ +"""Tokenization module.""" + +from .tokenizer import ( # noqa: F401 + TOKENIZER_FACTORY, + GenericTokenizer, + SelfiesTokenizer, + SmilesTokenizer, + Tokenizer, +) diff --git a/src/gt4sd/frameworks/granular/tokenizer/tokenizer.py b/src/gt4sd/frameworks/granular/tokenizer/tokenizer.py new file mode 100644 index 000000000..32dd76e53 --- /dev/null +++ b/src/gt4sd/frameworks/granular/tokenizer/tokenizer.py @@ -0,0 +1,509 @@ +"""Tokenizers implementations.""" + +import collections +import logging +import os +from typing import Dict, Iterable, List, Type + +import regex as re +import selfies as sf +from pytoda.smiles.processing import tokenize_selfies + +SMI_REGEX_PATTERN = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])" +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +def selfies_alphabet() -> 
List[str]:
+    """Legacy selfies 0.2.4 alphabet method.
+
+    Adapted from: https://github.com/aspuru-guzik-group/selfies/blob/84122855ae76a928e1cb7d58796b8b47385a4359/selfies/selfies.py#L4.
+
+    Returns:
+        SELFIES list of tokens.
+    """
+    alphabet = [
+        "[Branch1_1]",
+        "[Branch1_2]",
+        "[Branch1_3]",
+        "[Ring1]",
+        "[Branch2_1]",
+        "[Branch2_2]",
+        "[Branch2_3]",
+        "[Ring2]",
+        "[Branch3_1]",
+        "[Branch3_2]",
+        "[Branch3_3]",
+        "[Ring3]",
+        "[O]",
+        "[=O]",
+        "[N]",
+        "[=N]",
+        "[C]",
+        "[=C]",
+        "[#C]",
+        "[S]",
+        "[=S]",
+        "[P]",
+        "[F]",
+        "[C@Hexpl]",
+        "[C@@Hexpl]",
+        "[C@expl]",
+        "[C@@expl]",
+        "[H]",
+        "[NHexpl]",
+    ]
+    return alphabet
+
+
+def load_vocab(vocab_file: str) -> Dict[str, int]:
+    """Loads a vocabulary file into a dictionary.
+
+    Args:
+        vocab_file: vocabulary file.
+
+    Returns:
+        vocabulary mapping tokens to indices.
+    """
+    vocab = collections.OrderedDict()
+    with open(vocab_file, "r", encoding="utf-8") as reader:
+        tokens = reader.readlines()
+    for index, token in enumerate(tokens):
+        token = token.rstrip("\n")
+        vocab[token] = index
+    return vocab
+
+
+class BasicTokenizer:
+    """Basic tokenizer."""
+
+    def __init__(
+        self,
+        pad_token: str = "<pad>",
+        sos_token: str = "<start>",
+        eos_token: str = "<end>",
+        unk_token: str = "<unk>",
+    ) -> None:
+        """Constructs a BasicTokenizer.
+
+        Args:
+            pad_token: padding token. Defaults to '<pad>'.
+            sos_token: start of sequence token. Defaults to '<start>'.
+            eos_token: end of sequence token. Defaults to '<end>'.
+            unk_token: unknown token. Defaults to '<unk>'.
+        """
+        self.pad_token = pad_token
+        self.sos_token = sos_token
+        self.eos_token = eos_token
+        self.unk_token = unk_token
+
+    def tokenize(self, text: str) -> List[str]:
+        """Tokenize input text.
+
+        Args:
+            text: text to tokenize.
+
+        Returns:
+            list of tokens.
+        """
+        return list(text)
+
+    def build_vocab(self, smiles: Iterable[str], vocab_file: str) -> List[str]:
+        """Build and save a vocabulary given a SMILES list.
+
+        Args:
+            smiles: iterable of SMILES.
+            vocab_file: path to a file where the vocabulary is saved.
+
+        Returns:
+            a list of all tokens in the vocabulary.
+        """
+        tokens = set([self.pad_token, self.sos_token, self.eos_token, self.unk_token])
+
+        for smile in smiles:
+            tokens_temp = self.tokenize(smile)
+
+            for token in tokens_temp:
+                tokens.add(token)
+
+        tokens_list = sorted(list(tokens))
+
+        with open(vocab_file, "w") as f:
+            for item in tokens_list:
+                f.write(f"{item}{os.linesep}")
+
+        return tokens_list
+
+
+class BasicSmilesTokenizer(BasicTokenizer):
+    """Basic SMILES tokenizer."""
+
+    def __init__(
+        self,
+        regex_pattern: str = SMI_REGEX_PATTERN,
+        pad_token: str = "<pad>",
+        sos_token: str = "<start>",
+        eos_token: str = "<end>",
+        unk_token: str = "<unk>",
+    ) -> None:
+        """Constructs a BasicSmilesTokenizer.
+
+        Args:
+            regex_pattern: regex pattern. Defaults to SMI_REGEX_PATTERN.
+            pad_token: padding token. Defaults to '<pad>'.
+            sos_token: start of sequence token. Defaults to '<start>'.
+            eos_token: end of sequence token. Defaults to '<end>'.
+            unk_token: unknown token. Defaults to '<unk>'.
+        """
+        self.regex_pattern = regex_pattern
+        self.regex = re.compile(self.regex_pattern)
+        super().__init__(
+            pad_token=pad_token,
+            sos_token=sos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+        )
+
+    def tokenize(self, text: str) -> List[str]:
+        """Tokenize input text.
+
+        Args:
+            text: text to tokenize.
+
+        Returns:
+            list of tokens.
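+
+        Example (illustrative; the output follows from SMI_REGEX_PATTERN)::
+
+            BasicSmilesTokenizer().tokenize("c1ccccc1[NH3+]")
+            # -> ['c', '1', 'c', 'c', 'c', 'c', 'c', '1', '[NH3+]']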
+ """ + return [token for token in self.regex.findall(text)] + + +class BasicSelfiesTokenizer(BasicTokenizer): + """Basic SELFIES tokenizer.""" + + def __init__( + self, + pad_token: str = "", + sos_token: str = "", + eos_token: str = "", + unk_token: str = "", + ) -> None: + """Constructs a BasicSelfiesTokenizer. + + Args: + pad_token: padding token. Defaults to ''. + sos_token: start of sequence token. Defaults to ''. + eos_token: end of sequence token. Defaults to ''. + unk_token: unknown token. Defaults to ''. + """ + self.pad_token = pad_token + self.sos_token = sos_token + self.eos_token = eos_token + self.unk_token = unk_token + + def smiles_to_selfies(self, smiles: Iterable[str]) -> List[str]: + """Convert a list of SMILES into SELFIES. + + Args: + smiles: a list of SMILES. + + Returns: + a list of SELFIES. + """ + return [sf.encoder(a_smiles) for a_smiles in smiles] + + def tokenize(self, text: str) -> List[str]: + """Tokenize input text. + + Args: + text: text to tokenize. + + Returns: + list of tokens. + """ + return tokenize_selfies(sf.encoder(text)) + + def build_vocab(self, smiles: Iterable[str], vocab_file: str) -> List[str]: + """Build and save a vocabulary given a SMILES list. + + Args: + smiles: iterable of SMILES. + vocab_file: path to a file where the vocabulary is saved. + + Returns: + a list of all tokens in the vocabulary. + """ + selfies = self.smiles_to_selfies(smiles) + tokens = set( + [self.pad_token, self.sos_token, self.eos_token, self.unk_token, "[.]"] + + selfies_alphabet() + ) + for a_selfies in selfies: + tokens = tokens | set(tokenize_selfies(a_selfies)) + + tokens_list = sorted(list(tokens)) + + with open(vocab_file, "w") as f: + for item in tokens_list: + f.write(f"{item}{os.linesep}") + + return tokens_list + + +class Tokenizer: + """Tokenizer that can build a vocabulary on the fly.""" + + def __init__( + self, + vocab_file: str, + basic_tokenizer: BasicTokenizer = BasicTokenizer(), + smiles: List[str] = [], + pad_token: str = "", + sos_token: str = "", + eos_token: str = "", + unk_token: str = "", + ) -> None: + + """Constructs a Tokenizer. + + Args: + vocab_file: path to vocabulary file. If the file is not present, the provided SMILES list + is used to generate one. + basic_tokenizer: a basic tokenizer. Defaults to BasicTokenizer character tokenizer. + smiles: list of smiles. Default to empty list, used only if the vocabulary file does not exist. + pad_token: padding token. Defaults to ''. + sos_token: start of sequence token. Defaults to ''. + eos_token: end of sequence token. Defaults to ''. + unk_token: unknown token. Defaults to ''. + """ + self.basic_tokenizer = basic_tokenizer + self.pad_token = pad_token + self.sos_token = sos_token + self.eos_token = eos_token + self.unk_token = unk_token + + # load or build vocab + if os.path.isfile(vocab_file) and len(smiles) == 0: + logger.info(f"load vocab from: {vocab_file}") + self.vocab = load_vocab(vocab_file) + else: + logger.info("build and vocabulary") + self.basic_tokenizer.build_vocab(smiles, vocab_file) + logger.info(f"saved vocabulary: {vocab_file}") + self.vocab = load_vocab(vocab_file) + self.vocab_ids = {token: index for token, index in self.vocab.items()} + self.ids_to_tokens = collections.OrderedDict( + [(ids, tok) for tok, ids in self.vocab.items()] + ) + + self.pad_token_id = self.vocab.get(pad_token, self.vocab[self.unk_token]) + self.sos_token_id = self.vocab.get(sos_token, self.vocab[self.unk_token]) + + @property + def vocab_size(self) -> int: + """Size of the vocabulary. 
+
+        Returns:
+            the vocabulary size.
+        """
+        return len(self.vocab)
+
+    @property
+    def vocab_list(self) -> List[str]:
+        """Return vocabulary tokens.
+
+        Returns:
+            all tokens from the vocabulary.
+        """
+        return list(self.vocab.keys())
+
+    def tokenize(self, text: str) -> List[str]:
+        """Tokenize a given text.
+
+        Args:
+            text: text to tokenize.
+
+        Returns:
+            list of tokens.
+        """
+        return [token for token in self.basic_tokenizer.tokenize(text)]
+
+    def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]:
+        """Convert tokens to indices.
+
+        Args:
+            tokens: list of tokens.
+
+        Returns:
+            list of indices.
+        """
+        ids = []
+        for token in tokens:
+            ids.append(self.convert_token_to_id(token))
+        return ids
+
+    def convert_token_to_id(self, token: str) -> int:
+        """Convert token to index.
+
+        Args:
+            token: a token.
+
+        Returns:
+            index corresponding to the input token. Unknown token index if the input
+            token is not present in the vocabulary.
+        """
+        return self.vocab.get(token, self.vocab[self.unk_token])
+
+    def convert_id_to_token(self, index: int) -> str:
+        """Convert index to token.
+
+        Args:
+            index: an index.
+
+        Returns:
+            token corresponding to the input index. Unknown token if the input
+            index is not found.
+        """
+        return self.ids_to_tokens.get(index, self.unk_token)
+
+    def add_padding_tokens(
+        self, token_ids: List[int], length: int, right: bool = True
+    ) -> List[int]:
+        """Add padding token indices to the provided token indices.
+
+        Args:
+            token_ids: token indices.
+            length: length of the sequence.
+            right: whether the padding is performed on the right. Defaults to True, if False
+                the padding happens on the left.
+
+        Returns:
+            the padded sequence.
+        """
+        padding = [self.pad_token_id] * (length - len(token_ids))
+        if right:
+            return token_ids + padding
+        else:
+            return padding + token_ids
+
+
+class GenericTokenizer(Tokenizer):
+    """Generic tokenizer that can build a vocabulary on the fly."""
+
+    def __init__(
+        self,
+        vocab_file: str,
+        smiles: List[str] = [],
+        pad_token: str = "<pad>",
+        sos_token: str = "<start>",
+        eos_token: str = "<end>",
+        unk_token: str = "<unk>",
+    ) -> None:
+        """Constructs a GenericTokenizer.
+
+        Args:
+            vocab_file: path to vocabulary file. If the file is not present, the provided SMILES list
+                is used to generate one.
+            smiles: list of smiles. Defaults to an empty list, used only if the vocabulary file does not exist.
+            pad_token: padding token. Defaults to '<pad>'.
+            sos_token: start of sequence token. Defaults to '<start>'.
+            eos_token: end of sequence token. Defaults to '<end>'.
+            unk_token: unknown token. Defaults to '<unk>'.
+        """
+        super().__init__(
+            vocab_file=vocab_file,
+            basic_tokenizer=BasicTokenizer(
+                pad_token=pad_token,
+                sos_token=sos_token,
+                eos_token=eos_token,
+                unk_token=unk_token,
+            ),
+            smiles=smiles,
+            pad_token=pad_token,
+            sos_token=sos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+        )
+
+
+class SmilesTokenizer(Tokenizer):
+    """SMILES tokenizer that can build a vocabulary on the fly."""
+
+    def __init__(
+        self,
+        vocab_file: str,
+        smiles: List[str] = [],
+        pad_token: str = "<pad>",
+        sos_token: str = "<start>",
+        eos_token: str = "<end>",
+        unk_token: str = "<unk>",
+    ) -> None:
+        """Constructs a SmilesTokenizer.
+
+        Args:
+            vocab_file: path to vocabulary file. If the file is not present, the provided SMILES list
+                is used to generate one.
+            smiles: list of smiles. Defaults to an empty list, used only if the vocabulary file does not exist.
+            pad_token: padding token. Defaults to '<pad>'.
+            sos_token: start of sequence token. Defaults to '<start>'.
+            eos_token: end of sequence token. Defaults to '<end>'.
+            unk_token: unknown token. Defaults to '<unk>'.
+        """
+        super().__init__(
+            vocab_file=vocab_file,
+            basic_tokenizer=BasicSmilesTokenizer(
+                pad_token=pad_token,
+                sos_token=sos_token,
+                eos_token=eos_token,
+                unk_token=unk_token,
+            ),
+            smiles=smiles,
+            pad_token=pad_token,
+            sos_token=sos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+        )
+
+
+class SelfiesTokenizer(Tokenizer):
+    """SELFIES tokenizer that can build a vocabulary on the fly."""
+
+    def __init__(
+        self,
+        vocab_file: str,
+        smiles: List[str] = [],
+        pad_token: str = "<pad>",
+        sos_token: str = "<start>",
+        eos_token: str = "<end>",
+        unk_token: str = "<unk>",
+    ) -> None:
+        """Constructs a SelfiesTokenizer.
+
+        Args:
+            vocab_file: path to vocabulary file. If the file is not present, the provided SMILES list
+                is used to generate one.
+            smiles: list of smiles. Defaults to an empty list, used only if the vocabulary file does not exist.
+            pad_token: padding token. Defaults to '<pad>'.
+            sos_token: start of sequence token. Defaults to '<start>'.
+            eos_token: end of sequence token. Defaults to '<end>'.
+            unk_token: unknown token. Defaults to '<unk>'.
+        """
+        super().__init__(
+            vocab_file=vocab_file,
+            basic_tokenizer=BasicSelfiesTokenizer(
+                pad_token=pad_token,
+                sos_token=sos_token,
+                eos_token=eos_token,
+                unk_token=unk_token,
+            ),
+            smiles=smiles,
+            pad_token=pad_token,
+            sos_token=sos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+        )
+
+
+TOKENIZER_FACTORY: Dict[str, Type[Tokenizer]] = {
+    "generic": GenericTokenizer,
+    "smiles": SmilesTokenizer,
+    "selfies": SelfiesTokenizer,
+}
diff --git a/src/gt4sd/frameworks/torch/__init__.py b/src/gt4sd/frameworks/torch/__init__.py
new file mode 100644
index 000000000..04b9940ea
--- /dev/null
+++ b/src/gt4sd/frameworks/torch/__init__.py
@@ -0,0 +1,47 @@
+"""Generic utils for pytorch."""
+
+from typing import Optional, Union
+
+import torch
+
+
+def get_device() -> torch.device:
+    """
+    Get device dynamically.
+    """
+    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def device_claim(device: Optional[Union[torch.device, str]] = None) -> torch.device:
+    """
+    Satisfy a device claim.
+
+    Args:
+        device: device where the inference is running, either as a dedicated class
+            or a string. If not provided, it is inferred.
+
+    Returns:
+        torch.device: the claimed device or a default one.
+    """
+    if isinstance(device, str):
+        device = torch.device(device)
+    device = (
+        get_device()
+        if (device is None or not isinstance(device, torch.device))
+        else device
+    )
+    return device
+
+
+def get_device_from_tensor(tensor: torch.Tensor) -> torch.device:
+    """Get the device from a tensor.
+
+    Args:
+        tensor: a tensor.
+
+    Returns:
+        the device.
+    """
+    device_id = tensor.get_device()
+    device = "cpu" if device_id < 0 else f"cuda:{device_id}"
+    return device_claim(device)
diff --git a/src/gt4sd/frameworks/torch/vae.py b/src/gt4sd/frameworks/torch/vae.py
new file mode 100644
index 000000000..9e7f71357
--- /dev/null
+++ b/src/gt4sd/frameworks/torch/vae.py
@@ -0,0 +1,17 @@
+"""pytorch utils for VAEs."""
+
+import torch
+
+
+def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
+    """
+    Applies the reparametrization trick to obtain a sample from the latent space.
+
+    Args:
+        mu: the latent means of shape batch_size x latent_size.
+        logvar: latent log variances, shape batch_size x latent_size.
+
+    Returns:
+        torch.Tensor: sampled Z from the latent distribution.
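+
+    Example (sketch)::
+
+        mu = torch.zeros(8, 196)
+        logvar = torch.zeros(8, 196)  # log variance 0, i.e. unit variance
+        z = reparameterize(mu, logvar)  # z ~ N(mu, exp(logvar)), shape (8, 196)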
+ """ + return torch.randn_like(mu).mul_(torch.exp(0.5 * logvar)).add_(mu) # type:ignore diff --git a/src/gt4sd/py.typed b/src/gt4sd/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/src/gt4sd/s3.py b/src/gt4sd/s3.py new file mode 100644 index 000000000..2e487a9e8 --- /dev/null +++ b/src/gt4sd/s3.py @@ -0,0 +1,165 @@ +"""S3 storage utilities.""" + +import logging +import os +from typing import List, Optional, Set + +from minio import Minio + +from .exceptions import S3SyncError + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class GT4SDS3Client: + def __init__( + self, host: str, access_key: str, secret_key: str, secure: bool = True + ) -> None: + """ + Construct an S3 client. + + Args: + host: s3 host address. + access_key: s3 access key. + secret_key: s3 secret key. + secure: whether the connection is secure or not. Defaults + to True. + """ + self.host = host + self.access_key = access_key + self.secret_key = secret_key + self.secure = secure + self.client = Minio( + self.host, + access_key=self.access_key, + secret_key=self.secret_key, + secure=self.secure, + ) + + def list_bucket_names(self) -> List[str]: + """ + List all available s3 bucket names. + + Returns: + List[str]: list with bucket names. + """ + return [bucket.name for bucket in self.client.list_buckets()] + + def list_object_names(self, bucket: str, prefix: Optional[str] = None) -> List[str]: + """ + List all available objects (recursive) in the given bucket based on a given prefix. + + Args: + bucket: bucket name to search for objects. + prefix: prefix for objects in the bucket. + Defaults to None, a.k.a., no prefix filter. + + Returns: + List[str]: list with object names. + """ + return [ + s3_object.object_name + for s3_object in self.client.list_objects( + bucket_name=bucket, prefix=prefix, recursive=True + ) + ] + + def list_directories(self, bucket: str, prefix: Optional[str] = None) -> Set[str]: + """ + List all available "directories" in the given bucket based on a given prefix. + + Args: + bucket: bucket name to search for objects. + prefix: prefix for objects in the bucket. + Defaults to None, a.k.a., no prefix filter. + Needs to be a "directory" itself. + + Returns: + List[str]: list with directory names. + """ + if prefix: + prefix = prefix + "/" if prefix[-1] != "/" else prefix + return set( + s3_object.object_name[len(prefix) if prefix else 0 : -1] + for s3_object in self.client.list_objects( + bucket_name=bucket, prefix=prefix, recursive=False + ) + if s3_object.object_name[-1] == "/" + ) + + def sync_folder( + self, bucket: str, path: str, prefix: Optional[str] = None, force: bool = False + ) -> None: + """Sync an entire folder from S3 recursively and save it under the given path. + + If :obj:`prefix` is given, every file under ``prefix/`` in S3 will be saver under ``path/`` in disk (i.e. + ``prefix/`` is replaced by ``path/``). + + + Args: + bucket: bucket name to search for objects. + path: path to save the objects in disk. + prefix: prefix for objects in the bucket. Defaults to None, a.k.a., no prefix filter. + force: force download even if a file with the same name is present. Defaults to False. 
+ """ + if not os.path.exists(path): + logger.warning(f"path {path} does not exist, creating it...") + os.makedirs(path) + s3_objects = self.client.list_objects( + bucket_name=bucket, prefix=prefix, recursive=True + ) + for s3_object in s3_objects: + object_name = s3_object.object_name + object_name_stripped_prefix = ( + os.path.relpath(object_name, prefix) if prefix else object_name + ) + filepath = os.path.join(path, object_name_stripped_prefix) + # check for existence + do_download = not os.path.exists(filepath) + if do_download or force: + logger.info(f"downloading file {object_name} in {filepath}") + self.client.fget_object( + bucket_name=bucket, object_name=object_name, file_path=filepath + ) + + +def sync_folder_with_s3( + host: str, + access_key: str, + secret_key: str, + bucket: str, + folder_path: str, + prefix: Optional[str] = None, + secure: bool = True, +) -> None: + """ + Sync the cache with the S3 remote storage. + + Args: + host: s3 host address. + access_key: s3 access key. + secret_key: s3 secret key. + bucket: bucket name to search for objects. + folder_path: folder path. + prefix: prefix for objects in the bucket. Defaults to None, a.k.a., no prefix filter. + secure: whether the connection is secure or not. Defaults + to True. + + Raises: + S3SyncError: in case of S3 syncing errors. + """ + path = os.path.join(folder_path, prefix) if prefix else folder_path + try: + client = GT4SDS3Client( + host=host, access_key=access_key, secret_key=secret_key, secure=secure + ) + logger.info("starting syncing") + client.sync_folder(bucket=bucket, path=path, prefix=prefix) + logger.info("syncing complete") + except Exception: + logger.exception("generic syncing error") + raise S3SyncError( + "CacheSyncingError", + f"error in syncing path={path} with host={host} access_key={access_key} secret_key={secret_key} secure={secure} bucket={bucket}", + ) diff --git a/src/gt4sd/tests/__init__.py b/src/gt4sd/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/gt4sd/tests/test_configuration.py b/src/gt4sd/tests/test_configuration.py new file mode 100644 index 000000000..0a653aa84 --- /dev/null +++ b/src/gt4sd/tests/test_configuration.py @@ -0,0 +1,35 @@ +"""Configuration tests.""" + +import os + +from gt4sd.configuration import ( + GT4SDConfiguration, + get_algorithm_subdirectories_in_cache, + get_algorithm_subdirectories_with_s3, +) + +gt4sd_configuration_instance = GT4SDConfiguration.get_instance() + + +def test_default_local_cache_path(): + if "GT4SD_LOCAL_CACHE_PATH" not in os.environ: + assert os.path.dirname( + gt4sd_configuration_instance.gt4sd_local_cache_path + ) == os.path.expanduser("~") + assert ( + os.path.basename(gt4sd_configuration_instance.gt4sd_local_cache_path) + == ".gt4sd" + ) + else: + assert ( + gt4sd_configuration_instance.gt4sd_local_cache_path + == os.environ["GT4SD_LOCAL_CACHE_PATH"] + ) + + +def test_get_algorithm_subdirectories_with_s3(): + assert isinstance(get_algorithm_subdirectories_with_s3(), set) + + +def test_get_algorithm_subdirectories_in_cache(): + assert isinstance(get_algorithm_subdirectories_in_cache(), set) diff --git a/src/gt4sd/tests/test_exceptions.py b/src/gt4sd/tests/test_exceptions.py new file mode 100644 index 000000000..7aa8ec05d --- /dev/null +++ b/src/gt4sd/tests/test_exceptions.py @@ -0,0 +1,37 @@ +"""Exceptions tests.""" + +import pytest + +from gt4sd.exceptions import InvalidAlgorithmConfiguration, InvalidItem, S3SyncError + + +def test_s3_sync_error(): + error = S3SyncError("GenericSyncError", "my message") + 
+    assert error.type == "S3SyncError"
+    assert error.title == "GenericSyncError"
+    assert error.detail == "my message"
+    with pytest.raises(RuntimeError):
+        raise error
+
+
+def test_invalid_item():
+    error = InvalidItem("GenericInvalidItemError", "my message")
+    assert error.type == "InvalidItem"
+    assert error.title == "GenericInvalidItemError"
+    assert error.detail == "my message"
+    with pytest.raises(ValueError):
+        raise error
+
+
+def test_invalid_algorithm_configuration():
+    error = InvalidAlgorithmConfiguration(
+        "GenericAlgorithmConfigurationError", "my message"
+    )
+    assert error.type == "InvalidAlgorithmConfiguration"
+    assert error.title == "GenericAlgorithmConfigurationError"
+    assert error.detail == "my message"
+    with pytest.raises(ValueError):
+        raise error
diff --git a/src/gt4sd/tests/test_s3.py b/src/gt4sd/tests/test_s3.py
new file mode 100644
index 000000000..aa2fea0b1
--- /dev/null
+++ b/src/gt4sd/tests/test_s3.py
@@ -0,0 +1,94 @@
+"""S3 storage tests.
+
+These assume a bucket `gt4sd-ci-tests` and an artifact `a_folder/containing/a_file.txt`.
+"""
+
+import logging
+import os
+import shutil
+import tempfile
+
+import pytest
+
+from gt4sd.s3 import GT4SDS3Client, sync_folder_with_s3
+from gt4sd.tests.utils import GT4SDTestSettings
+
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler())
+
+test_settings = GT4SDTestSettings.get_instance()
+
+
+@pytest.fixture
+def client() -> GT4SDS3Client:
+    return GT4SDS3Client(
+        host=test_settings.gt4sd_s3_host,
+        access_key=test_settings.gt4sd_s3_access_key,
+        secret_key=test_settings.gt4sd_s3_secret_key,
+        secure=test_settings.gt4sd_s3_secure,
+    )
+
+
+def test_bucket_listing(client: GT4SDS3Client):
+    bucket_names = client.list_bucket_names()
+    assert isinstance(bucket_names, list)
+    assert "gt4sd-ci-tests" in bucket_names
+
+
+def test_object_listing(client: GT4SDS3Client):
+    object_names = client.list_object_names(bucket="gt4sd-ci-tests", prefix="a_folder")
+    assert isinstance(object_names, list)
+    for object_name in object_names:
+        assert object_name.endswith(".txt")
+
+
+def test_directory_listing(client: GT4SDS3Client):
+    object_names = client.list_directories(bucket="gt4sd-ci-tests", prefix="a_folder")
+    assert isinstance(object_names, set)
+    for object_name in object_names:
+        assert object_name == "containing"
+
+
+def test_sync_folder(client: GT4SDS3Client):
+    directory = tempfile.mkdtemp()
+    try:
+        client.sync_folder("gt4sd-ci-tests", directory)
+        filepaths = set(
+            [filepath.replace(directory, "") for filepath in os.listdir(directory)]
+        )
+        objectpaths = set(
+            map(
+                lambda path: path.split("/")[0],
+                client.list_object_names(bucket="gt4sd-ci-tests"),
+            )
+        )
+        assert filepaths == objectpaths
+    finally:
+        logger.info(f"cleaning up test folder {directory}")
+        shutil.rmtree(directory)
+
+
+def test_sync_folder_with_s3(client):
+    directory = tempfile.mkdtemp()
+    try:
+        sync_folder_with_s3(
+            host=test_settings.gt4sd_s3_host,
+            access_key=test_settings.gt4sd_s3_access_key,
+            secret_key=test_settings.gt4sd_s3_secret_key,
+            bucket="gt4sd-ci-tests",
+            folder_path=directory,
+            secure=test_settings.gt4sd_s3_secure,
+        )
+        filepaths = set(
+            [filepath.replace(directory, "") for filepath in os.listdir(directory)]
+        )
+        objectpaths = set(
+            map(
+                lambda path: path.split("/")[0],
+                client.list_object_names(bucket="gt4sd-ci-tests"),
+            )
+        )
+        assert filepaths == objectpaths
+    finally:
+        logger.info(f"cleaning up test folder {directory}")
+        shutil.rmtree(directory)
diff --git a/src/gt4sd/tests/utils.py b/src/gt4sd/tests/utils.py
new file mode 100644
index 000000000..82c6eddc4
--- /dev/null
+++ b/src/gt4sd/tests/utils.py
@@ -0,0 +1,24 @@
+"""Utilities used in the tests."""
+
+from functools import lru_cache
+
+from pydantic import BaseSettings
+
+
+class GT4SDTestSettings(BaseSettings):
+    """Utility variables for the tests setup."""
+
+    gt4sd_s3_host: str = "localhost:9000"
+    gt4sd_s3_access_key: str = "access-key"
+    gt4sd_s3_secret_key: str = "secret-key"
+    gt4sd_s3_secure: bool = False
+    gt4sd_ci: bool = False
+
+    class Config:
+        # immutable and in turn hashable, as required by lru_cache
+        frozen = True
+
+    @staticmethod
+    @lru_cache(maxsize=None)
+    def get_instance() -> "GT4SDTestSettings":
+        return GT4SDTestSettings()
diff --git a/src/gt4sd/training_pipelines/__init__.py b/src/gt4sd/training_pipelines/__init__.py
new file mode 100644
index 000000000..8405fcf15
--- /dev/null
+++ b/src/gt4sd/training_pipelines/__init__.py
@@ -0,0 +1,95 @@
+"""Module initialization for gt4sd training pipelines."""
+
+import json
+import logging
+from typing import Any, Dict
+
+import pkg_resources
+
+from ..cli.load_arguments_from_dataclass import extract_fields_from_class
+from .paccmann.core import PaccMannDataArguments, PaccMannTrainingArguments
+from .paccmann.vae.core import PaccMannVAEModelArguments, PaccMannVAETrainingPipeline
+from .pytorch_lightning.core import PytorchLightningTrainingArguments
+from .pytorch_lightning.granular.core import (
+    GranularDataArguments,
+    GranularModelArguments,
+    GranularTrainingPipeline,
+)
+from .pytorch_lightning.language_modeling.core import (
+    LanguageModelingDataArguments,
+    LanguageModelingModelArguments,
+    LanguageModelingTrainingPipeline,
+)
+
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler())
+
+TRAINING_PIPELINE_NAME_METADATA_MAPPING = {
+    "mock_training_pipeline": "mock_training_pipeline.json",
+    "Terminator training": "terminator_training.json",
+}
+
+TRAINING_PIPELINE_ARGUMENTS_MAPPING = {
+    "language-modeling-trainer": (
+        PytorchLightningTrainingArguments,
+        LanguageModelingDataArguments,
+        LanguageModelingModelArguments,
+    ),
+    "paccmann-vae-trainer": (
+        PaccMannTrainingArguments,
+        PaccMannDataArguments,
+        PaccMannVAEModelArguments,
+    ),
+    "granular-trainer": (
+        PytorchLightningTrainingArguments,
+        GranularDataArguments,
+        GranularModelArguments,
+    ),
+}
+
+TRAINING_PIPELINE_MAPPING = {
+    "language-modeling-trainer": LanguageModelingTrainingPipeline,
+    "paccmann-vae-trainer": PaccMannVAETrainingPipeline,
+    "granular-trainer": GranularTrainingPipeline,
+}
+
+
+def training_pipeline_name_to_metadata(name: str) -> Dict[str, Any]:
+    """Retrieve training pipeline metadata from the name.
+
+    Args:
+        name: name of the pipeline.
+
+    Returns:
+        dictionary describing the parameters of the pipeline. If the pipeline is not found, no metadata (a.k.a., an empty dictionary) is returned.
+    """
+ """ + metadata: Dict[str, Any] = {"training_pipeline": name, "parameters": {}} + if name in TRAINING_PIPELINE_NAME_METADATA_MAPPING: + try: + with open( + pkg_resources.resource_filename( + "gt4sd", + f"training_pipelines/{TRAINING_PIPELINE_NAME_METADATA_MAPPING[name]}", + ), + "rt", + ) as fp: + metadata["parameters"] = json.load(fp) + except Exception: + logger.exception( + f'training pipeline "{name}" metadata fetching failed, returning an empty metadata dictionary' + ) + + elif name in TRAINING_PIPELINE_ARGUMENTS_MAPPING: + + for training_argument_class in TRAINING_PIPELINE_ARGUMENTS_MAPPING[name]: + field_types = extract_fields_from_class(training_argument_class) + metadata["parameters"].update(field_types) + + else: + logger.warning( + f'training pipeline "{name}" metadata not found, returning an empty metadata dictionary' + ) + metadata["description"] = metadata["parameters"].pop( + "description", "A training pipeline." + ) + return metadata diff --git a/src/gt4sd/training_pipelines/core.py b/src/gt4sd/training_pipelines/core.py new file mode 100644 index 000000000..9c2f09152 --- /dev/null +++ b/src/gt4sd/training_pipelines/core.py @@ -0,0 +1,18 @@ +"""Core training utilities.""" + +from dataclasses import dataclass + + +class TrainingPipeline: + """Abstract interface for a training pipelines.""" + + def train(self, **kwargs) -> None: + """Train the models associated to a pipeline.""" + raise NotImplementedError("Can't train an abstract training pipeline.") + + +@dataclass +class TrainingPipelineArguments: + """Abstract interface for training pipeline arguments.""" + + __name__ = "training_args" diff --git a/src/gt4sd/training_pipelines/mock_training_pipeline.json b/src/gt4sd/training_pipelines/mock_training_pipeline.json new file mode 100644 index 000000000..9e7fd1ec6 --- /dev/null +++ b/src/gt4sd/training_pipelines/mock_training_pipeline.json @@ -0,0 +1,17 @@ +{ + "description": "Pipeline mocking a training process lasting 30 seconds.", + "batch_size": { + "type": "integer", + "default": 32, + "description": "Batch size for training.", + "example": 32, + "optional": false + }, + "epochs": { + "type": "integer", + "default": 8, + "description": "The batch size per GPU core/CPU for training.", + "example": 8, + "optional": false + } +} \ No newline at end of file diff --git a/src/gt4sd/training_pipelines/paccmann/__init__.py b/src/gt4sd/training_pipelines/paccmann/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/gt4sd/training_pipelines/paccmann/core.py b/src/gt4sd/training_pipelines/paccmann/core.py new file mode 100644 index 000000000..cd49670c9 --- /dev/null +++ b/src/gt4sd/training_pipelines/paccmann/core.py @@ -0,0 +1,105 @@ +"""PaccMann training utilities.""" + +from dataclasses import dataclass, field +from typing import Any, Dict + +from ..core import TrainingPipeline, TrainingPipelineArguments + + +class PaccMannTrainingPipeline(TrainingPipeline): + """PyTorch lightining training pipelines.""" + + def train( # type: ignore + self, + training_args: Dict[str, Any], + model_args: Dict[str, Any], + dataset_args: Dict[str, Any], + ) -> None: + """Generic training function for PaccMann training. + + Args: + training_args: training arguments passed to the configuration. + model_args: model arguments passed to the configuration. + dataset_args: dataset arguments passed to the configuration. + + Raises: + NotImplementedError: the generic trainer does not implement the pipeline. 
+ """ + raise NotImplementedError + + +@dataclass +class PaccMannTrainingArguments(TrainingPipelineArguments): + """Arguments related to PaccMann trainer.""" + + __name__ = "training_args" + + model_path: str = field(metadata={"help": "Path where the model artifacts."}) + training_name: str = field(metadata={"help": "Name used to identify the training."}) + epochs: int = field(default=50, metadata={"help": "Number of epochs."}) + batch_size: int = field(default=256, metadata={"help": "Size of the batch."}) + learning_rate: float = field( + default=0.0005, metadata={"help": "Learning rate used in training."} + ) + optimizer: str = field( + default="adam", metadata={"help": "Optimizer used during training."} + ) + log_interval: int = field( + default=100, metadata={"help": "Number of steps between log intervals."} + ) + save_interval: int = field( + default=1000, metadata={"help": "Number of steps between model save intervals."} + ) + eval_interval: int = field( + default=500, metadata={"help": "Number of steps between evaluation intervals."} + ) + + +@dataclass +class PaccMannDataArguments(TrainingPipelineArguments): + """Arguments related to PaccMann data loading.""" + + __name__ = "dataset_args" + + train_smiles_filepath: str = field( + metadata={"help": "Training file containing SMILES in .smi format."} + ) + test_smiles_filepath: str = field( + metadata={"help": "Testing file containing SMILES in .smi format."} + ) + smiles_language_filepath: str = field( + default="none", metadata={"help": "Optional SMILES language file."} + ) + add_start_stop_token: bool = field( + default=True, metadata={"help": "Whether start and stop token should be added."} + ) + selfies: bool = field( + default=True, metadata={"help": "Whether SELFIES representations are used."} + ) + num_workers: int = field( + default=0, metadata={"help": "Number of workers used in data loading."} + ) + pin_memory: bool = field( + default=False, metadata={"help": "Whether memory in the data loader is pinned."} + ) + augment_smiles: bool = field( + default=False, metadata={"help": "Whether SMILES augumentation is used."} + ) + canonical: bool = field( + default=False, metadata={"help": "Whether SMILES canonicalization is used."} + ) + kekulize: bool = field( + default=False, metadata={"help": "Whether SMILES kekulization is used."} + ) + all_bonds_explicit: bool = field( + default=False, metadata={"help": "Whether all bonds are explicit."} + ) + all_hs_explicit: bool = field( + default=False, metadata={"help": "Whether all hydrogens are explicit."} + ) + remove_bonddir: bool = field( + default=False, metadata={"help": "Remove bond directionality."} + ) + remove_chirality: bool = field( + default=False, metadata={"help": "Remove chirality."} + ) diff --git a/src/gt4sd/training_pipelines/paccmann/vae/__init__.py b/src/gt4sd/training_pipelines/paccmann/vae/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/gt4sd/training_pipelines/paccmann/vae/core.py b/src/gt4sd/training_pipelines/paccmann/vae/core.py new file mode 100644 index 000000000..f33795808 --- /dev/null +++ b/src/gt4sd/training_pipelines/paccmann/vae/core.py @@ -0,0 +1,309 @@ +"""PaccMann VAE training utilities.""" + +import json +import logging +import os +from dataclasses import dataclass, field +from time import time +from typing import Any, Dict, Optional, cast + +import torch +from paccmann_chemistry.models.training import train_vae +from paccmann_chemistry.models.vae import StackGRUDecoder, StackGRUEncoder, TeacherVAE +from 
paccmann_chemistry.utils import collate_fn, disable_rdkit_logging +from paccmann_chemistry.utils.hyperparams import SEARCH_FACTORY +from pytoda.datasets import SMILESDataset +from pytoda.smiles.smiles_language import SMILESLanguage +from torch.utils.tensorboard import SummaryWriter + +from ....frameworks.torch import get_device +from ...core import TrainingPipelineArguments +from ..core import PaccMannTrainingPipeline + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class PaccMannVAETrainingPipeline(PaccMannTrainingPipeline): + """Language modeling training pipelines.""" + + def train( # type: ignore + self, + training_args: Dict[str, Any], + model_args: Dict[str, Any], + dataset_args: Dict[str, Any], + ) -> None: + """Generic training function for PaccMann training. + + Args: + training_args: training arguments passed to the configuration. + model_args: model arguments passed to the configuration. + dataset_args: dataset arguments passed to the configuration. + """ + try: + device = get_device() + disable_rdkit_logging() + params = {**training_args, **dataset_args, **model_args} + train_smiles_filepath = params["train_smiles_filepath"] + test_smiles_filepath = params["test_smiles_filepath"] + smiles_language_filepath = ( + params["smiles_language_filepath"] + if params.get("smiles_language_filepath", "none").lower() != "none" + else None + ) + + model_path = params["model_path"] + training_name = params["training_name"] + + writer = SummaryWriter(f"logs/{training_name}") + logger.info(f"Model with name {training_name} starts.") + + model_dir = os.path.join(model_path, training_name) + log_path = os.path.join(model_dir, "logs") + val_dir = os.path.join(log_path, "val_logs") + os.makedirs(os.path.join(model_dir, "weights"), exist_ok=True) + os.makedirs(os.path.join(model_dir, "results"), exist_ok=True) + os.makedirs(log_path, exist_ok=True) + os.makedirs(val_dir, exist_ok=True) + + # Load SMILES language + smiles_language: Optional[SMILESLanguage] = None + if smiles_language_filepath is not None: + smiles_language = SMILESLanguage.load(smiles_language_filepath) + + logger.info(f"Smiles filepath: {train_smiles_filepath}") + + # create SMILES eager dataset + smiles_train_data = SMILESDataset( + train_smiles_filepath, + smiles_language=smiles_language, + padding=False, + selfies=params.get("selfies", False), + add_start_and_stop=params.get("add_start_stop_token", True), + augment=params.get("augment_smiles", False), + canonical=params.get("canonical", False), + kekulize=params.get("kekulize", False), + all_bonds_explicit=params.get("all_bonds_explicit", False), + all_hs_explicit=params.get("all_hs_explicit", False), + remove_bonddir=params.get("remove_bonddir", False), + remove_chirality=params.get("remove_chirality", False), + backend="lazy", + device=device, + ) + smiles_test_data = SMILESDataset( + test_smiles_filepath, + smiles_language=smiles_language, + padding=False, + selfies=params.get("selfies", False), + add_start_and_stop=params.get("add_start_stop_token", True), + augment=params.get("augment_smiles", False), + canonical=params.get("canonical", False), + kekulize=params.get("kekulize", False), + all_bonds_explicit=params.get("all_bonds_explicit", False), + all_hs_explicit=params.get("all_hs_explicit", False), + remove_bonddir=params.get("remove_bonddir", False), + remove_chirality=params.get("remove_chirality", False), + backend="lazy", + device=device, + ) + + if smiles_language_filepath is None: + smiles_language = 
smiles_train_data.smiles_language + smiles_language.save(os.path.join(model_path, f"{training_name}.lang")) + else: + smiles_language_filename = os.path.basename(smiles_language_filepath) + cast(SMILESLanguage, smiles_language).save( + os.path.join(model_dir, smiles_language_filename) + ) + + params.update( + { + "vocab_size": cast( + SMILESLanguage, smiles_language + ).number_of_tokens, + "pad_index": cast(SMILESLanguage, smiles_language).padding_index, + } + ) + + vocab_dict = cast(SMILESLanguage, smiles_language).index_to_token + params.update( + { + "start_index": list(vocab_dict.keys())[ + list(vocab_dict.values()).index("") + ], + "end_index": list(vocab_dict.keys())[ + list(vocab_dict.values()).index("") + ], + } + ) + + if params.get("embedding", "learned") == "one_hot": + params.update({"embedding_size": params["vocab_size"]}) + + with open(os.path.join(model_dir, "model_params.json"), "w") as fp: + json.dump(params, fp) + + # create DataLoaders + train_data_loader = torch.utils.data.DataLoader( + smiles_train_data, + batch_size=params.get("batch_size", 64), + collate_fn=collate_fn, + drop_last=True, + shuffle=True, + pin_memory=params.get("pin_memory", True), + num_workers=params.get("num_workers", 8), + ) + + test_data_loader = torch.utils.data.DataLoader( + smiles_test_data, + batch_size=params.get("batch_size", 64), + collate_fn=collate_fn, + drop_last=True, + shuffle=True, + pin_memory=params.get("pin_memory", True), + num_workers=params.get("num_workers", 8), + ) + # initialize encoder and decoder + gru_encoder = StackGRUEncoder(params).to(device) + gru_decoder = StackGRUDecoder(params).to(device) + gru_vae = TeacherVAE(gru_encoder, gru_decoder).to(device) + logger.info("Model summary:") + for name, parameter in gru_vae.named_parameters(): + logger.info(f"Param {name}, shape:\t{parameter.shape}") + total_params = sum(p.numel() for p in gru_vae.parameters()) + logger.info(f"Total # params: {total_params}") + + loss_tracker = { + "test_loss_a": 10e4, + "test_rec_a": 10e4, + "test_kld_a": 10e4, + "ep_loss": 0, + "ep_rec": 0, + "ep_kld": 0, + } + + # train for n_epoch epochs + logger.info("Model creation and data processing done, Training starts.") + decoder_search = SEARCH_FACTORY[params.get("decoder_search", "sampling")]( + temperature=params.get("temperature", 1.0), + beam_width=params.get("beam_width", 3), + top_tokens=params.get("top_tokens", 5), + ) + + if writer: + pparams = params.copy() + pparams["training_file"] = train_smiles_filepath + pparams["test_file"] = test_smiles_filepath + pparams["language_file"] = smiles_language_filepath + pparams["model_path"] = model_path + pparams = {k: v if v is not None else "N.A." 
for k, v in params.items()} + pparams["training_name"] = training_name + from pprint import pprint + + pprint(pparams) + writer.add_hparams(hparam_dict=pparams, metric_dict={}) + + for epoch in range(params["epochs"] + 1): + t = time() + loss_tracker = train_vae( + epoch, + gru_vae, + train_data_loader, + test_data_loader, + smiles_language, + model_dir, + search=decoder_search, + optimizer=params.get("optimizer", "adadelta"), + lr=params["learning_rate"], + kl_growth=params["kl_growth"], + input_keep=params["input_keep"], + test_input_keep=params["test_input_keep"], + generate_len=params["generate_len"], + log_interval=params["log_interval"], + save_interval=params["save_interval"], + eval_interval=params["eval_interval"], + loss_tracker=loss_tracker, + logger=logger, + # writer=writer, + batch_mode=params.get("batch_mode"), + ) + logger.info(f"Epoch {epoch}, took {time() - t:.1f}.") + + logger.info( + "Overall:\tBest loss = {0:.4f} in Ep {1}, " + "best Rec = {2:.4f} in Ep {3}, " + "best KLD = {4:.4f} in Ep {5}".format( + loss_tracker["test_loss_a"], + loss_tracker["ep_loss"], + loss_tracker["test_rec_a"], + loss_tracker["ep_rec"], + loss_tracker["test_kld_a"], + loss_tracker["ep_kld"], + ) + ) + logger.info("Training done, shutting down.") + except Exception: + logger.exception( + "Exception occurred while running PaccMannVAETrainingPipeline" + ) + + +@dataclass +class PaccMannVAEModelArguments(TrainingPipelineArguments): + """Arguments pertaining to model instantiation.""" + + __name__ = "model_args" + + n_layers: int = field( + default=2, metadata={"help": "Number of layers for the RNNs."} + ) + bidirectional: bool = field( + default=False, metadata={"help": "Whether the RNN cells are bidirectional."} + ) + rnn_cell_size: int = field(default=512, metadata={"help": "Size of the RNN cells."}) + latent_dim: int = field(default=256, metadata={"help": "Size of the RNN cells."}) + stack_width: int = field( + default=50, metadata={"help": "Width of the memory stack for the RNN cell."} + ) + stack_depth: int = field( + default=50, metadata={"help": "Depth of the memory stack for the RNN cell."} + ) + decode_search: str = field( + default="sampling", metadata={"help": "Decoder search strategy."} + ) + dropout: float = field(default=0.2, metadata={"help": "Dropout rate to apply."}) + generate_len: int = field( + default=100, metadata={"help": "Length in tokens of the generated molecules."} + ) + kl_growth: float = field( + default=0.003, metadata={"help": "Growth of the KL term weight in the loss."} + ) + input_keep: float = field( + default=0.85, metadata={"help": "Probability to keep input tokens in train."} + ) + test_input_keep: float = field( + default=1.0, metadata={"help": "Probability to keep input tokens in test."} + ) + temperature: float = field( + default=0.8, metadata={"help": "Temperature for the sampling."} + ) + embedding: str = field( + default="one_hot", + metadata={ + "help": "Embedding technique for the tokens. 'one_hot' or 'learned'." + }, + ) + batch_mode: str = field( + default="packed", metadata={"help": "Batch mode. 
'packed' or 'padded'."} + ) + vocab_size: int = field( + default=380, metadata={"help": "Size of the vocabulary of chemical tokens."} + ) + pad_index: int = field(default=0, metadata={"help": "Index for the padding token."}) + embedding_size: int = field( + default=380, metadata={"help": "Size of the embedding vectors."} + ) + beam_width: int = field(default=3, metadata={"help": "Width of the beam search."}) + top_tokens: int = field( + default=5, metadata={"help": "Number of tokens to consider in the beam search."} + ) diff --git a/src/gt4sd/training_pipelines/pytorch_lightning/__init__.py b/src/gt4sd/training_pipelines/pytorch_lightning/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/gt4sd/training_pipelines/pytorch_lightning/core.py b/src/gt4sd/training_pipelines/pytorch_lightning/core.py new file mode 100644 index 000000000..ceb57f8ae --- /dev/null +++ b/src/gt4sd/training_pipelines/pytorch_lightning/core.py @@ -0,0 +1,178 @@ +"""PyTorch Lightning training utilities.""" + +import logging +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Union + +from pytorch_lightning import LightningDataModule, LightningModule, Trainer +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint + +from ..core import TrainingPipeline, TrainingPipelineArguments + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class PyTorchLightningTrainingPipeline(TrainingPipeline): + """PyTorch lightining training pipelines.""" + + def train( # type: ignore + self, + pl_trainer_args: Dict[str, Any], + model_args: Dict[str, Union[float, str, int]], + dataset_args: Dict[str, Union[float, str, int]], + ) -> None: + """Generic training function for PyTorch Lightning-based training. + + Args: + pl_trainer_args: pytorch lightning trainer arguments passed to the configuration. + model_args: model arguments passed to the configuration. + dataset_args: dataset arguments passed to the configuration. + """ + + logger.info(f"Trainer arguments: {pl_trainer_args}") + + if pl_trainer_args[ + "resume_from_checkpoint" + ] is not None and not pl_trainer_args["resume_from_checkpoint"].endswith( + ".ckpt" + ): + pl_trainer_args["resume_from_checkpoint"] = None + + pl_trainer_args["callbacks"] = { + "model_checkpoint_callback": { + "monitor": pl_trainer_args["monitor"], + "save_top_k": pl_trainer_args["save_top_k"], + "mode": pl_trainer_args["mode"], + "every_n_train_steps": pl_trainer_args["every_n_train_steps"], + "save_last": pl_trainer_args["save_last"], + } + } + + del ( + pl_trainer_args["monitor"], + pl_trainer_args["save_top_k"], + pl_trainer_args["mode"], + pl_trainer_args["every_n_train_steps"], + pl_trainer_args["save_last"], + ) + + pl_trainer_args["callbacks"] = self.add_callbacks(pl_trainer_args["callbacks"]) + + trainer = Trainer(**pl_trainer_args) + data_module, model_module = self.get_data_and_model_modules( + model_args, dataset_args + ) + trainer.fit(model_module, data_module) + + def get_data_and_model_modules( + self, + model_args: Dict[str, Union[float, str, int]], + dataset_args: Dict[str, Union[float, str, int]], + ) -> Tuple[LightningDataModule, LightningModule]: + """Get data and model modules for training. + + Args: + model_args: model arguments passed to the configuration. + dataset_args: dataset arguments passed to the configuration. 
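To make the extension point concrete, a minimal hypothetical subclass (not part of this patch) showing what the module-building hook is expected to return; `MyDataModule` and `MyModelModule` are illustrative stubs only:

```python
from typing import Any, Dict, Tuple

from pytorch_lightning import LightningDataModule, LightningModule

from gt4sd.training_pipelines.pytorch_lightning.core import (
    PyTorchLightningTrainingPipeline,
)


class MyDataModule(LightningDataModule):
    """Hypothetical data module; dataloader hooks omitted in this sketch."""


class MyModelModule(LightningModule):
    """Hypothetical model module; training hooks omitted in this sketch."""


class MyTrainingPipeline(PyTorchLightningTrainingPipeline):
    def get_data_and_model_modules(
        self, model_args: Dict[str, Any], dataset_args: Dict[str, Any]
    ) -> Tuple[LightningDataModule, LightningModule]:
        # train() builds the Trainer and the checkpoint callbacks, then
        # calls trainer.fit(model_module, data_module) on this pair.
        return MyDataModule(), MyModelModule()
```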
+ + Returns: + the data and model modules. + """ + raise NotImplementedError( + "Can't get data and model modules for an abstract training pipeline." + ) + + def add_callbacks(self, callback_args: Dict[str, Any]) -> List[Callback]: + """Create the requested callbacks for training. + + Args: + callback_args: callback arguments passed to the configuration. + + Returns: + list of pytorch lightning callbacks. + """ + + callbacks: List[Callback] = [] + if "early_stopping_callback" in callback_args: + callbacks.append(EarlyStopping(**callback_args["early_stopping_callback"])) + + if "model_checkpoint_callback" in callback_args: + callbacks.append( + ModelCheckpoint(**callback_args["model_checkpoint_callback"]) + ) + + return callbacks + + +@dataclass +class PytorchLightningTrainingArguments(TrainingPipelineArguments): + """ + Arguments related to pytorch lightning trainer. + """ + + __name__ = "pl_trainer_args" + + accelerator: Union[str, None] = field( + default="ddp", metadata={"help": "Accelerator type."} + ) + accumulate_grad_batches: int = field( + default=1, + metadata={ + "help": "Accumulates grads every k batches or as set up in the dict." + }, + ) + val_check_interval: int = field( + default=5000, metadata={"help": " How often to check the validation set."} + ) + default_root_dir: Union[str, None] = field( + default=None, metadata={"help": "Default path for logs and output."} + ) + + gradient_clip_val: float = field( + default=0.0, metadata={"help": "Gradient clipping value."} + ) + limit_val_batches: int = field( + default=500, metadata={"help": "How much of validation dataset to check."} + ) + log_every_n_steps: int = field( + default=500, metadata={"help": "How often to log within steps."} + ) + max_epochs: int = field( + default=3, + metadata={"help": "Stop training once this number of epochs is reached."}, + ) + resume_from_checkpoint: Union[str, None] = field( + default=None, + metadata={"help": "Path/URL of the checkpoint from which training is resumed."}, + ) + gpus: Union[int, None] = field( + default=-1, + metadata={"help": "Number of gpus to train on."}, + ) + monitor: Union[str, None] = field( + default=None, + metadata={"help": "Quantity to monitor in order to store a checkpoint."}, + ) + save_last: bool = field( + default=True, + metadata={ + "help": "When True, always saves the model at the end of the epoch to a file last.ckpt" + }, + ) + save_top_k: Optional[int] = field( + default=None, + metadata={ + "help": "The best k models according to the quantity monitored will be saved." 
+ }, + ) + mode: str = field( + default="min", + metadata={"help": "Quantity to monitor in order to store a checkpoint."}, + ) + every_n_train_steps: Union[int, None] = field( + default=None, + metadata={"help": "Number of training steps between checkpoints."}, + ) diff --git a/src/gt4sd/training_pipelines/pytorch_lightning/granular/__init__.py b/src/gt4sd/training_pipelines/pytorch_lightning/granular/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/gt4sd/training_pipelines/pytorch_lightning/granular/core.py b/src/gt4sd/training_pipelines/pytorch_lightning/granular/core.py new file mode 100644 index 000000000..8ad403646 --- /dev/null +++ b/src/gt4sd/training_pipelines/pytorch_lightning/granular/core.py @@ -0,0 +1,153 @@ +"""Granular training utilities.""" + +import json +import logging +from argparse import Namespace +from dataclasses import dataclass, field +from typing import Any, Dict, Optional, Tuple + +from pytorch_lightning import LightningDataModule, LightningModule + +from ....frameworks.granular.dataloader.data_module import GranularDataModule +from ....frameworks.granular.dataloader.dataset import build_dataset_and_architecture +from ....frameworks.granular.ml.models import AUTOENCODER_ARCHITECTURES +from ....frameworks.granular.ml.module import GranularModule +from ...core import TrainingPipelineArguments +from ..core import PyTorchLightningTrainingPipeline + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class GranularTrainingPipeline(PyTorchLightningTrainingPipeline): + """Granular training pipelines.""" + + def get_data_and_model_modules( + self, + model_args: Dict[str, Any], + dataset_args: Dict[str, Any], + **kwargs, + ) -> Tuple[LightningDataModule, LightningModule]: + """Get data and model modules for training. + + Args: + model_args: model arguments passed to the configuration. + dataset_args: dataset arguments passed to the configuration. + + Returns: + the data and model modules. 
+ """ + + configuration = {**model_args, **dataset_args} + + with open(model_args["model_list_path"], "r") as fp: # type:ignore + configuration["model_list"] = json.load(fp)["models"] + + arguments = Namespace(**configuration) + datasets = [] + architecture_autoencoders = [] + architecture_latent_models = [] + for model in arguments.model_list: + logger.info(f"dataset preparation for model={model}") + hparams = configuration["model_list"][model] + hparams["name"] = model + model_type = hparams["type"].lower() + dataset, architecture = build_dataset_and_architecture( + hparams["name"], + hparams["data_path"], + hparams["data_file"], + hparams["dataset_type"], + hparams["type"], + hparams, + ) + datasets.append(dataset) + if model_type in AUTOENCODER_ARCHITECTURES: + architecture_autoencoders.append(architecture) + else: + architecture_latent_models.append(architecture) + dm = GranularDataModule( + datasets, + batch_size=arguments.batch_size, + validation_split=arguments.validation_split, + validation_indices_file=arguments.validation_indices_file, + stratified_batch_file=arguments.stratified_batch_file, + stratified_value_name=arguments.stratified_value_name, + num_workers=arguments.num_workers, + ) + dm.prepare_data() + module = GranularModule( + architecture_autoencoders=architecture_autoencoders, + architecture_latent_models=architecture_latent_models, + lr=arguments.lr, + test_output_path=arguments.test_output_path, + ) + + return dm, module + + +@dataclass +class GranularModelArguments(TrainingPipelineArguments): + """ + Arguments related to model. + """ + + __name__ = "model_args" + + model_list_path: str = field( + metadata={ + "help": "Path to a json file that contains a dictionary with models and their parameters." + }, + ) + lr: float = field( + default=0.0001, + metadata={"help": "The learning rate."}, + ) + + test_output_path: Optional[str] = field( + default="./test", + metadata={ + "help": "Path where to save latent encodings and predictions for the test set when an epoch ends. Defaults to a a folder called 'test' in the current working directory." + }, + ) + + +@dataclass +class GranularDataArguments(TrainingPipelineArguments): + """ + Arguments related to data. + """ + + __name__ = "dataset_args" + + batch_size: int = field( + default=64, + metadata={"help": "Batch size of the training. Defaults to 64."}, + ) + validation_split: Optional[float] = field( + default=None, + metadata={ + "help": "Proportion used for validation. Defaults to None, a.k.a., use indices file if provided otherwise uses half of the data for validation." + }, + ) + validation_indices_file: Optional[str] = field( + default=None, + metadata={ + "help": "Indices to use for validation. Defaults to None, a.k.a., use validation split proportion, if not provided uses half of the data for validation." + }, + ) + stratified_batch_file: Optional[str] = field( + default=None, + metadata={ + "help": "Stratified batch file for sampling. Defaults to None, a.k.a., no stratified sampling." + }, + ) + stratified_value_name: Optional[str] = field( + default=None, + metadata={ + "help": "Stratified value name. Defaults to None, a.k.a., no stratified sampling. Needed in case a stratified batch file is provided." + }, + ) + num_workers: int = field( + default=1, + metadata={"help": "number of workers. 
Defaults to 1."}, + ) diff --git a/src/gt4sd/training_pipelines/pytorch_lightning/language_modeling/__init__.py b/src/gt4sd/training_pipelines/pytorch_lightning/language_modeling/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/gt4sd/training_pipelines/pytorch_lightning/language_modeling/core.py b/src/gt4sd/training_pipelines/pytorch_lightning/language_modeling/core.py new file mode 100644 index 000000000..0f05bb2d1 --- /dev/null +++ b/src/gt4sd/training_pipelines/pytorch_lightning/language_modeling/core.py @@ -0,0 +1,248 @@ +"""Language modeling training utilities.""" + +import logging +from dataclasses import dataclass, field +from typing import Dict, Optional, Tuple, Union + +from pytorch_lightning import LightningDataModule, LightningModule + +from ...core import TrainingPipelineArguments +from ..core import PyTorchLightningTrainingPipeline +from .lm_datasets import CGMDataModule, CLMDataModule, MLMDataModule, PLMDataModule +from .models import CGMModule, CLMModule, MLMModule, PLMModule + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class LanguageModelingTrainingPipeline(PyTorchLightningTrainingPipeline): + """Language modeling training pipelines.""" + + def get_data_and_model_modules( + self, + model_args: Dict[str, Union[float, str, int]], + dataset_args: Dict[str, Union[float, str, int]], + **kwargs, + ) -> Tuple[LightningDataModule, LightningModule]: + """Get data and model modules for training. + + Args: + model_args: model arguments passed to the configuration. + dataset_args: dataset arguments passed to the configuration. + + Returns: + the data and model modules. + """ + + if ( + model_args["model_config_name"] is None + and model_args["model_name_or_path"] is None + ): + raise ValueError("Model config or model name/path should be provided") + + if ( + model_args["model_config_name"] is not None + and model_args["model_name_or_path"] is not None + ): + logger.warning( + "Config name is omitted. Start fine-tuning using {}".format( + model_args["model_name_or_path"] + ) + ) + + if model_args["tokenizer"] is None: + + if model_args["model_name_or_path"] is not None: + model_args["tokenizer"] = model_args["model_name_or_path"] + else: + model_args["tokenizer"] = model_args["model_config_name"] + logger.warning( + "{} tokenizer is going to be used in the training".format( + model_args["tokenizer"] + ) + ) + + logger.info(f"Model arguments: {model_args}") + logger.info(f"Dataset arguments: {dataset_args}") + + if model_args["type"] == "mlm": + data_module, model_module = self.get_mlm_modules(model_args, dataset_args) + elif model_args["type"] == "clm": + data_module, model_module = self.get_clm_modules(model_args, dataset_args) # type: ignore + elif model_args["type"] == "plm": + data_module, model_module = self.get_plm_modules(model_args, dataset_args) # type: ignore + elif model_args["type"] == "cgm": + data_module, model_module = self.get_cgm_modules(model_args, dataset_args) # type: ignore + else: + raise ValueError(f"LM training type {model_args['type']} not supported") + + model_module.model.resize_token_embeddings(len(data_module.tokenizer)) # type: ignore + + return data_module, model_module + + def get_mlm_modules( + self, + model_args: Dict[str, Union[float, str, int]], + dataset_args: Dict[str, Union[float, str, int]], + ) -> Tuple[MLMDataModule, MLMModule]: + """Get model and data module for clm. + + Args: + model_args: dictionary containing all the parameters for the mode configuration. 
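A hedged sketch of argument dictionaries that would drive the type dispatch above for an MLM run; the checkpoint name and file paths are placeholders, and the full set of accepted keys is defined by the dataclasses further down in this file:

```python
# Placeholder argument dictionaries for an MLM training run.
model_args = {
    "type": "mlm",
    "model_name_or_path": "bert-base-uncased",
    "model_config_name": None,
    "tokenizer": None,  # falls back to the checkpoint tokenizer, as above
    "lr": 2e-5,
    "lr_decay": 0.5,
}
dataset_args = {
    "train_file": "path/to/train.jsonl",
    "validation_file": "path/to/valid.jsonl",
    "mlm_probability": 0.15,
    "batch_size": 8,
}
```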
+ dataset_args: dictionary containing all the necessary parameters for the dataset creation. + Returns: + model and data module for clm. + """ + + model_module = MLMModule(model_args) + data_module = MLMDataModule(dataset_args, tokenizer=model_module.tokenizer) + + return data_module, model_module + + def get_clm_modules( + self, + model_args: Dict[str, Union[float, str, int]], + dataset_args: Dict[str, Union[float, str, int]], + ) -> Tuple[CLMDataModule, CLMModule]: + """Get model and data module for clm. + + Args: + model_args: dictionary containing all the parameters for the mode configuration. + dataset_args: dictionary containing all the necessary parameters for the dataset creation. + Returns: + model and data module for clm. + """ + + model_module = CLMModule(model_args) + data_module = CLMDataModule(dataset_args, tokenizer=model_module.tokenizer) + + return data_module, model_module + + def get_plm_modules( + self, + model_args: Dict[str, Union[float, str, int]], + dataset_args: Dict[str, Union[float, str, int]], + ) -> Tuple[PLMDataModule, PLMModule]: + """Get model and data module for plm. + + Args: + model_args: dictionary containing all the parameters for the mode configuration. + dataset_args: dictionary containing all the necessary parameters for the dataset creation. + Returns: + model and data module for plm. + """ + + model_module = PLMModule(model_args) + data_module = PLMDataModule(dataset_args, tokenizer=model_module.tokenizer) + + return data_module, model_module + + def get_cgm_modules( + self, + model_args: Dict[str, Union[float, str, int]], + dataset_args: Dict[str, Union[float, str, int]], + ) -> Tuple[CGMDataModule, CGMModule]: + """Get model and data module for Conditional Generation model. + + Args: + model_args: dictionary containing all the parameters for the mode configuration. + dataset_args: dictionary containing all the necessary parameters for the dataset creation. + Returns: + model and data module for plm. + """ + + model_module = CGMModule(model_args) + data_module = CGMDataModule(dataset_args, tokenizer=model_module.tokenizer) + + return data_module, model_module + + +@dataclass +class LanguageModelingModelArguments(TrainingPipelineArguments): + """ + Arguments pertaining to which model/config we are going to fine-tune, or train from scratch. + """ + + __name__ = "model_args" + + type: str = field( + metadata={"help": "The language modeling type, for example mlm."}, + ) + model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "The model checkpoint for weights initialization, for example bert-base-uncased." + }, + ) + model_config_name: Optional[str] = field( + default=None, + metadata={"help": "Pretrained config name or path."}, + ) + tokenizer: Optional[str] = field( + default=None, + metadata={ + "help": "Name of the tokenizer to be used, default: tokenizer of utilizing model." + }, + ) + lr: float = field( + default=2e-5, + metadata={"help": "The learning rate."}, + ) + lr_decay: float = field( + default=0.5, + metadata={"help": "The learning rate decay."}, + ) + cache_dir: Union[str, None] = field( + default=None, + metadata={ + "help": "Where do you want to store the pretrained models downloaded from huggingface.co." + }, + ) + + +@dataclass +class LanguageModelingDataArguments(TrainingPipelineArguments): + """ + Arguments pertaining to what data we are going to input our model for training and eval. 
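Because these containers are plain dataclasses with transformers-style `help` metadata, they can plausibly be populated from the command line; a sketch assuming transformers' `HfArgumentParser`:

```python
from transformers import HfArgumentParser

from gt4sd.training_pipelines import (
    LanguageModelingDataArguments,
    LanguageModelingModelArguments,
)

# Turn flags like `--type mlm --train_file data.jsonl` into dataclasses;
# the field metadata doubles as the CLI help text.
parser = HfArgumentParser(
    (LanguageModelingModelArguments, LanguageModelingDataArguments)
)
model_args, dataset_args = parser.parse_args_into_dataclasses()
print(model_args.type, dataset_args.train_file)
```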
+ """ + + __name__ = "dataset_args" + + train_file: str = field( + metadata={ + "help": "The input training data file (a text file), for example path/to/file." + } + ) + validation_file: str = field( + metadata={ + "help": "The input evaluation data file to evaluate the perplexity on (a text file), for example path/to/file." + }, + ) + max_length: int = field( + default=512, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated." + }, + ) + mlm_probability: float = field( + default=0.15, + metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}, + ) + plm_probability: float = field( + default=0.16666, + metadata={ + "help": "Ratio of length of a span of masked tokens to surrounding context length for " + "permutation language modeling." + }, + ) + max_span_length: int = field( + default=5, + metadata={ + "help": "Maximum length of a span of masked tokens for permutation language modeling." + }, + ) + batch_size: int = field( + default=8, + metadata={"help": "Ratio of tokens to mask for masked language modeling loss."}, + ) diff --git a/src/gt4sd/training_pipelines/pytorch_lightning/language_modeling/lm_datasets.py b/src/gt4sd/training_pipelines/pytorch_lightning/language_modeling/lm_datasets.py new file mode 100644 index 000000000..126aff99a --- /dev/null +++ b/src/gt4sd/training_pipelines/pytorch_lightning/language_modeling/lm_datasets.py @@ -0,0 +1,344 @@ +"""Dataset routines-filtering, dataset building.""" + +import json +import os +from functools import lru_cache +from typing import Any, Callable, Dict, List, Union + +import pytorch_lightning as pl +from datasets import DatasetDict +from loguru import logger +from torch.utils.data import ConcatDataset, DataLoader, Dataset +from transformers import ( + AutoTokenizer, + DataCollatorForLanguageModeling, + DataCollatorForPermutationLanguageModeling, + default_data_collator, +) +from transformers.tokenization_utils_base import BatchEncoding + + +class LMDataset(Dataset): + """LM dataset class.""" + + def __init__( + self, + filepath: str, + tokenizer: Callable, + ) -> None: + """Initialize the LM data module. + + Args: + filepath: path where the dataset is located. + tokenizer: tokenize function to be used in the module. + """ + + self.filepath = filepath + self.tokenizer = tokenizer + self.length = LMDataset.count_examples(filepath) + + if not self.filepath.endswith(".jsonl") and not self.filepath.endswith(".json"): + raise ValueError(f"{filepath} is not a .jsonl or a json.") + + @lru_cache() + def examples_reader(self) -> List[Dict[str, str]]: + """Read instances from a filepath. + + Returns: + list of instances. + """ + with open(self.filepath) as fp: + return [json.loads(line.strip()) for line in fp] + + @staticmethod + def count_examples(filepath: str) -> int: + """Count instances of a filepath. + + Args: + filepath: path of the dataset. + Returns: + number of examples existed in the given filepath. + """ + + def _make_gen(reader): + while True: + b = reader(2 ** 16) + if not b: + break + yield b + + with open(filepath, "rb") as f: + count = sum(buf.count(b"\n") for buf in _make_gen(f.raw.read)) # type: ignore + return count + + def __len__(self) -> int: + """Number of instances of the dataset. + + Returns: + number of instances + """ + return self.length + + def __getitem__(self, index) -> BatchEncoding: + """Get an item of the dataset. + + Args: + index: index of the item. + Returns: + tokenized item. 
+ """ + + examples = self.examples_reader() + example = self.tokenizer(examples[index]) + + return example + + +class DataModule(pl.LightningDataModule): + """Pytorch-lightning-style data module for LM dataset.""" + + def __init__(self, dataset_args: Dict[str, Any], tokenizer: AutoTokenizer) -> None: + """Initialize the data module. + + Args: + dataset_args: dictionary containing the arguments for the lightning data module creation. + tokenizer: tokenizer to be used in the module. + """ + + super().__init__() + + self.dataset: DatasetDict + + self.dataset_args = dataset_args + + self.tokenizer = tokenizer + + self.data_collator = default_data_collator + + if "num_dataloader_workers" not in self.dataset_args: + + self.dataset_args["num_dataloader_workers"] = 8 + + cpus_count = os.cpu_count() + if cpus_count is not None: + self.dataset_args["num_dataloader_workers"] = min( + self.dataset_args["num_dataloader_workers"], cpus_count + ) + + def build_dataset(self, path: str) -> Dataset: + """ + Build the dataset. + + Args: + path: path where the dataset is located. + Returns: + a torch Dataset. + """ + + if path.endswith(".jsonl") or path.endswith(".json"): + return LMDataset(path, self.tokenize_function) + elif os.path.isdir(path): + return ConcatDataset( + datasets=[ + LMDataset(os.path.join(path, filename), self.tokenize_function) + for filename in os.listdir(path) + if filename.endswith(".jsonl") or filename.endswith(".json") + ] + ) + else: + raise TypeError(f"{path} type is not supported for dataset") + + def tokenize_function( + self, examples: Dict[str, Union[int, slice]] + ) -> BatchEncoding: + """Tokenize the given examples. + + Args: + examples: list of examples. + Returns: + tokenized examples. + """ + + truncation = self.dataset_args.get("truncation", True) + padding = self.dataset_args.get("padding", "max_length") + max_length = self.dataset_args.get("max_length", 512) + + return self.tokenizer( # type: ignore + examples["text"], + truncation=truncation, + padding=padding, + max_length=max_length, + ) + + def load(self) -> None: + """Load datasets from the given files.""" + + self.datasets = { + "train": self.build_dataset(self.dataset_args["train_file"]), + "validation": self.build_dataset(self.dataset_args["validation_file"]), + } + + logger.info( + f"Training set size: {len(self.datasets['train'])} - Validation set size: {len(self.datasets['validation'])}" # type: ignore + ) + + def train_dataloader(self) -> DataLoader: + """Create the DataLoader for the traning step. + + Returns: + pytorch-like dataloader. + """ + return DataLoader( + self.datasets["train"], # type: ignore + batch_size=self.dataset_args["batch_size"], + num_workers=self.dataset_args["num_dataloader_workers"], + collate_fn=self.data_collator, + ) + + def val_dataloader(self) -> DataLoader: + """Create the DataLoader for the traning step. + + Returns: + pytorch-like dataloader. + """ + return DataLoader( + self.datasets["validation"], # type: ignore + batch_size=self.dataset_args["batch_size"], + num_workers=self.dataset_args["num_dataloader_workers"], + collate_fn=self.data_collator, + ) + + +class MLMDataModule(DataModule): + """Pytorch-lightning-style data module for MLM dataset.""" + + def __init__( + self, dataset_args: Dict[str, Union[float, str, int]], tokenizer: AutoTokenizer + ) -> None: + """Initialize the data module. + + Args: + dataset_args: dictionary containing the metadata for the lightning data module creation. + tokenizer: tokenizer to be used in the module. 
+ """ + super().__init__(dataset_args, tokenizer) + + self.data_collator = DataCollatorForLanguageModeling( + self.tokenizer, self.dataset_args["mlm_probability"] # type: ignore + ) + + self.load() + + +class CGMDataModule(DataModule): + """Pytorch-lightning-style data module for conditional generation dataset.""" + + def __init__( + self, dataset_args: Dict[str, Union[float, str, int]], tokenizer: AutoTokenizer + ) -> None: + """ + Initialize the data module. + + Args: + dataset_args: dictionary containing the metadata for the lightning data module creation. + tokenizer: tokenizer to be used in the module. + """ + super().__init__(dataset_args, tokenizer) + + self.load() + + def tokenize_function( + self, examples: Dict[str, Union[int, slice]] + ) -> BatchEncoding: + """Tokenize the given examples. + + Args: + examples: list of examples. + Returns: + tokenized examples. + """ + + truncation = self.dataset_args.get("truncation", True) + padding = self.dataset_args.get("padding", "max_length") + max_length = self.dataset_args.get("max_length", 512) + + source = self.tokenizer( # type: ignore + examples["source"], + truncation=truncation, + padding=padding, + max_length=max_length, + ) + + targets = self.tokenizer( # type: ignore + examples["target"], + truncation=truncation, + padding=padding, + max_length=max_length, + ) + + return BatchEncoding( + data={ + "input_ids": source["input_ids"], + "attention_mask": source["attention_mask"], + "labels": targets["input_ids"], + "decoder_attention_mask": targets["attention_mask"], + } + ) + + +class CLMDataModule(DataModule): + """Pytorch-lightning-style data module for CLM dataset.""" + + def __init__( + self, dataset_args: Dict[str, Union[float, str, int]], tokenizer: AutoTokenizer + ) -> None: + """ + Initialize the data module. + + Args: + dataset_args: dictionary containing the metadata for the lightning data module creation. + tokenizer: tokenizer to be used in the module. + """ + super().__init__(dataset_args, tokenizer) + + self.load() + + def tokenize_function( + self, examples: Dict[str, Union[int, slice]] + ) -> BatchEncoding: + """Tokenize the given examples. + + Args: + examples: list of examples. + Returns: + tokenized examples. + """ + + tokenized_data = super().tokenize_function(examples) + + tokenized_data["labels"] = tokenized_data["input_ids"].copy() + + return tokenized_data + + +class PLMDataModule(DataModule): + """Pytorch-lightning-style data module for PLM dataset.""" + + def __init__( + self, dataset_args: Dict[str, Union[float, str, int]], tokenizer: AutoTokenizer + ) -> None: + """Initialize the data module. + + Args: + dataset_args: dictionary containing the metadata for the lightning data module creation. + tokenizer: tokenizer to be used in the module. 
+ """ + super().__init__(dataset_args, tokenizer) + + self.data_collator = DataCollatorForPermutationLanguageModeling( + tokenizer=self.tokenizer, # type: ignore + plm_probability=self.dataset_args["plm_probability"], + max_span_length=self.dataset_args["max_span_length"], + ) + + self.load() diff --git a/src/gt4sd/training_pipelines/pytorch_lightning/language_modeling/models.py b/src/gt4sd/training_pipelines/pytorch_lightning/language_modeling/models.py new file mode 100644 index 000000000..c9bbefc80 --- /dev/null +++ b/src/gt4sd/training_pipelines/pytorch_lightning/language_modeling/models.py @@ -0,0 +1,271 @@ +"""Model for Language Modeling.""" + +import logging +from typing import Dict, Type, Union + +import pytorch_lightning as pl +import torch.optim as optim +from torch import Tensor +from transformers import ( + AutoConfig, + AutoModel, + AutoModelForCausalLM, + AutoModelForMaskedLM, + AutoTokenizer, + BartConfig, + BartForConditionalGeneration, + T5Config, + T5ForConditionalGeneration, + XLNetLMHeadModel, +) + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class LMModule(pl.LightningModule): + """Pytorch lightning model for LM training.""" + + def __init__( + self, + model_args: Dict[str, Union[float, int, str]], + ) -> None: + """Construct an LM lightning module. + + Args: + model_args: model's arguments. + """ + super().__init__() + + self.model_args = model_args + + self.model: AutoModel + self.tokenizer: AutoTokenizer + + self.cache_dir = None + if "cache_dir" in model_args: + self.cache_dir = model_args["cache_dir"] + + self.init_model() + + def init_model(self) -> None: + """Initialize an AutoModel.""" + + if self.model_args["model_name_or_path"] is not None: + self.model = AutoModel.from_pretrained( + self.model_args["model_name_or_path"], + cache_dir=self.cache_dir, + ) + else: + config = AutoConfig.from_pretrained( + self.model_args["model_config_name"], cache_dir=self.cache_dir + ) + + self.model = AutoModel.from_config(config) + + logger.info("Training from scratch") + + def forward(self, x: Tensor) -> Tensor: # type: ignore + """Forward pass on Transformer model. + + Args: + x: tensor of shape (batch_size, seq_length) containing the input_ids. + Returns: + logits of the model. + """ + return self.model(x).logits # type:ignore + + def configure_optimizers( + self, + ) -> Dict[str, object]: + """Create and return the optimizer. + + Returns: + output (dict of str: Any): + - optimizer: the optimizer used to update the parameter. + - ls_scheduler: the scheduler used to reduce the learning rate in every epoch. + - monitor: the metric that the scheduler will track over the training. + """ + + if not isinstance(self.model_args["lr"], float): + raise ValueError("Learning rate should be float") + + if not isinstance(self.model_args["lr_decay"], float): + raise ValueError("Learning rate decay rate should be float") + + optimizer = optim.AdamW(self.parameters(), lr=self.model_args["lr"]) + + scheduler = optim.lr_scheduler.StepLR(optimizer, 1, self.model_args["lr_decay"]) + + output = { + "optimizer": optimizer, + "lr_scheduler": scheduler, + "monitor": "val_loss", + } + return output + + def training_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Tensor: # type: ignore + """ + Training step which encompasses the forward pass and the computation of the loss value. + + Args: + batch: dictionary containing the input_ids and optionally the token_type_ids and the attention_type. + batch_idx: index of the current batch, unused. 
+ Returns: + loss computed on the batch. + """ + loss = self.model(**batch).loss # type:ignore + self.log("train_loss", loss) + return loss + + def validation_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Tensor: # type: ignore + """ + Validation step which encompasses the forward pass and the computation of the loss value. + + Args: + batch: dictionary containing the input_ids and optionally the token_type_ids and the attention_type. + batch_idx: index of the current batch, unused. + Returns: + loss computed on the batch. + """ + loss = self.model(**batch).loss # type:ignore + self.log("val_loss", loss) + return loss + + +class MLMModule(LMModule): + """Pytorch lightning model for MLM training.""" + + def init_model(self) -> None: + """Initialize a MLM model.""" + + if self.model_args["model_name_or_path"] is not None: + self.model = AutoModelForMaskedLM.from_pretrained( + self.model_args["model_name_or_path"], cache_dir=self.cache_dir + ) + else: + config = AutoConfig.from_pretrained( + self.model_args["model_config_name"], cache_dir=self.cache_dir + ) + + self.model = AutoModelForMaskedLM.from_config(config) + + logger.info("Training from scratch") + + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_args["tokenizer"], use_fast=False + ) + + self.model.resize_token_embeddings(len(self.tokenizer)) # type: ignore + + +class CGMModule(LMModule): + """Pytorch lightning model for conditional generation training.""" + + def init_model(self) -> None: + """Initialize a model for conditional generation.""" + + if self.model_args["model_name_or_path"] is not None: + if "t5" in self.model_args["model_name_or_path"]: # type:ignore + self.model = T5ForConditionalGeneration.from_pretrained( + self.model_args["model_name_or_path"], # type:ignore + cache_dir=self.cache_dir, + ) + elif "bart" in self.model_args["model_name_or_path"]: # type:ignore + self.model = BartForConditionalGeneration.from_pretrained( + self.model_args["model_name_or_path"], # type:ignore + cache_dir=self.cache_dir, + ) + else: + raise ValueError( + f"{self.model_args['model_name_or_path']} is not supported for conditional generation training." + ) + else: + if "t5" in self.model_args["model_config_name"]: + config = T5Config.from_pretrained( + self.model_args["model_config_name"], cache_dir=self.cache_dir + ) + + self.model = T5ForConditionalGeneration.from_config(config) + elif "bart" in self.model_args["model_config_name"]: + config = BartConfig.from_pretrained( + self.model_args["model_config_name"], cache_dir=self.cache_dir + ) + + self.model = BartForConditionalGeneration.from_config(config) + else: + raise ValueError( + f"{self.model_args['model_name_or_path']} is not supported for conditional generation training." 
+ ) + + logger.info("Training from scratch") + + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_args["tokenizer"], use_fast=False + ) + + self.model.resize_token_embeddings(len(self.tokenizer)) # type: ignore + + +class CLMModule(LMModule): + """Pytorch lightning model for CLM training.""" + + def init_model(self) -> None: + """Initialize a CLM model.""" + + if self.model_args["model_name_or_path"] is not None: + self.model = AutoModelForCausalLM.from_pretrained( + self.model_args["model_name_or_path"], cache_dir=self.cache_dir + ) + else: + config = AutoConfig.from_pretrained( + self.model_args["model_config_name"], cache_dir=self.cache_dir + ) + + self.model = AutoModelForCausalLM.from_config(config) + + logger.info("Training from scratch") + + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_args["tokenizer"], + sep_token="<|sep|>", + pad_token="<|pad|>", + use_fast=False, + ) + + self.model.resize_token_embeddings(len(self.tokenizer)) # type: ignore + + +class PLMModule(LMModule): + """Pytorch lightning model for PLM training.""" + + def init_model(self) -> None: + """Initialize a PLM model. """ + + if self.model_args["model_name_or_path"] is not None: + self.model = XLNetLMHeadModel.from_pretrained( + self.model_args["model_name_or_path"], # type:ignore + cache_dir=self.cache_dir, + ) + else: + config = AutoConfig.from_pretrained( + self.model_args["model_config_name"], cache_dir=self.cache_dir + ) + + self.model = XLNetLMHeadModel.from_config(config) + + logger.info("Training from scratch") + + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_args["tokenizer"], use_fast=False + ) + + self.model.resize_token_embeddings(len(self.tokenizer)) # type: ignore + + +LM_MODULE_FACTORY: Dict[str, Type[LMModule]] = { + "lm": LMModule, + "mlm": MLMModule, + "clm": CLMModule, + "plm": PLMModule, +} diff --git a/src/gt4sd/training_pipelines/terminator_training.json b/src/gt4sd/training_pipelines/terminator_training.json new file mode 100644 index 000000000..e177c289e --- /dev/null +++ b/src/gt4sd/training_pipelines/terminator_training.json @@ -0,0 +1,170 @@ +{ + "description": "Pipeline for training a Terminator model, a multi-task transformer for conditional molecular design.", + "config_name": { + "type": "string", + "description": "Name of the configuration to be used for training.", + "optional": false + }, + "dataloader_drop_last": { + "type": "boolean", + "description": "Whether to drop the last batch of the dataloader.", + "default": true, + "optional": true + }, + "disable_tqdm": { + "type": "boolean", + "description": "Whether to disable progress informatiopn in the logs.", + "default": true, + "optional": true + }, + "do_eval": { + "type": "boolean", + "description": "Whether to perform the evaluation of the model on the validation set.", + "default": true, + "optional": true + }, + "do_train": { + "type": "boolean", + "description": "Whether to perform the training.", + "default": true, + "optional": true + }, + "eval_accumulation_steps": { + "type": "integer", + "description": "Number of accumulation steps during evaluation.", + "default": 2, + "example": 2, + "optional": true + }, + "eval_data_file": { + "type": "string", + "description": "File used for model evaluation.", + "optional": false + }, + "eval_steps": { + "type": "integer", + "description": "Number of evaluation steps.", + "default": 1000, + "example": 1000, + "optional": true + }, + "evaluate_during_training": { + "type": "boolean", + "description": "Whether to evaluate during 
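For orientation, a minimal sketch of how these modules are meant to be instantiated through the factory. The argument values mirror the unit-test fixture later in this patch; in normal use the language modeling training pipeline assembles this dictionary itself:

from gt4sd.training_pipelines.pytorch_lightning.language_modeling.models import (
    LM_MODULE_FACTORY,
)

# Model arguments mirroring the test fixture below; cache_dir is optional.
model_args = {
    "model_name_or_path": "albert-base-v2",
    "model_config_name": "albert-base-v2",
    "tokenizer": "albert-base-v2",
    "lr": 2e-5,
    "lr_decay": 0.5,
}

# Resolve the module class from the training type and build it; this
# downloads the pretrained weights and tokenizer on first use.
module = LM_MODULE_FACTORY["mlm"](model_args)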
training.", + "default": true, + "example": true, + "optional": true + }, + "gradient_accumulation_steps": { + "type": "integer", + "description": "Number of gradient accumulation steps.", + "default": 2, + "example": 2, + "optional": true + }, + "learning_rate": { + "type": "number", + "description": "Learning rate for the training.", + "default": 0.0001, + "example": 0.0001, + "optional": true + }, + "line_by_line": { + "type": "boolean", + "description": "Whether the files for training and evaluation contain an example per line.", + "default": true, + "example": true, + "optional": true + + }, + "logging_steps": { + "type": "integer", + "description": "Number of steps in between logging cycles.", + "default": 50, + "example": 50, + "optional": true + }, + "max_span_length": { + "type": "integer", + "description": "Maximum length of a span.", + "default": 5, + "example": 5, + "optional": true + }, + "max_steps": { + "type": "integer", + "description": "Maximum number of steps.", + "default": 20000, + "example": 20000, + "optional": true + }, + "model_name_or_path": { + "type": "string", + "description": "The model checkpoint for weights initialization.", + "optional": false + }, + "num_train_epochs": { + "type": "integer", + "description": "Total number of training epochs to perform.", + "example": 10, + "optional": false + }, + "output_dir": { + "type": "string", + "description": "The output directory where the model predictions and checkpoints will be written.", + "example": "path/to/output", + "optional": false + }, + "overwrite_output_dir": { + "type": "boolean", + "description": "Whether the output directory files can be overwritten.", + "default": true, + "example": true, + "optional": true + }, + "per_device_eval_batch_size": { + "type": "integer", + "default": 32, + "description": "The batch size per GPU core/CPU for evaluation.", + "example": 32, + "optional": false + }, + "per_device_train_batch_size": { + "type": "integer", + "default": 32, + "description": "The batch size per GPU core/CPU for training.", + "example": 32, + "optional": false + }, + "train_data_file": { + "type": "string", + "description": "File used for model training.", + "optional": false + }, + "tokenizer_name": { + "type": "string", + "description": "Name of the tokenizer to be used.", + "optional": false + }, + "save_steps": { + "type": "integer", + "description": "Number of steps in between model checkpoint savings.", + "default": 50, + "example": 50, + "optional": true + }, + "plm_probability": { + "type": "number", + "description": "Probability of masking tokens during permutation language modeling.", + "default": 0.2, + "example": 0.2, + "optional": true + }, + "save_total_limit": { + "type": "integer", + "description": "Total number of model checkpoints stored.", + "default": 2, + "example": 2, + "optional": true + } +} \ No newline at end of file diff --git a/src/gt4sd/training_pipelines/tests/__init__.py b/src/gt4sd/training_pipelines/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/gt4sd/training_pipelines/tests/lm_example.jsonl b/src/gt4sd/training_pipelines/tests/lm_example.jsonl new file mode 100644 index 000000000..9f1704d08 --- /dev/null +++ b/src/gt4sd/training_pipelines/tests/lm_example.jsonl @@ -0,0 +1,8 @@ +{"text": "The present invention provides compounds suitable for use in the treatment of conditions where it is beneficial to halt bone loss and kill cancer cells, particularly in metastases to and primary tumours in the bone and surrounding tissues. 
diff --git a/src/gt4sd/training_pipelines/tests/__init__.py b/src/gt4sd/training_pipelines/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/gt4sd/training_pipelines/tests/lm_example.jsonl b/src/gt4sd/training_pipelines/tests/lm_example.jsonl
new file mode 100644
index 000000000..9f1704d08
--- /dev/null
+++ b/src/gt4sd/training_pipelines/tests/lm_example.jsonl
@@ -0,0 +1,8 @@
+{"text": "The present invention provides compounds suitable for use in the treatment of conditions where it is beneficial to halt bone loss and kill cancer cells, particularly in metastases to and primary tumours in the bone and surrounding tissues. Consequently the present invention provides compounds comprising a bisphosphonate moiety linked to a phytochemical, pharmaceutical compositions thereof and methods of treatment of bone diseases and/or proliferative disorders.", "claim": "A compound of formula Q-T-L wherein Q is a bisphosphonate moiety, T is linker and L is an anti-osteolytic or osteoinductive phytochemical."}
+{"text": "The present invention provides compounds suitable for use in the treatment of conditions where it is beneficial to halt bone loss and kill cancer cells, particularly in metastases to and primary tumours in the bone and surrounding tissues. Consequently the present invention provides compounds comprising a bisphosphonate moiety linked to a phytochemical, pharmaceutical compositions thereof and methods of treatment of bone diseases and/or proliferative disorders.", "claim": "A compound of formula Q-T-L wherein Q is a bisphosphonate moiety, T is linker and L is an anti-osteolytic or osteoinductive phytochemical."}
+{"text": "The present invention provides compounds suitable for use in the treatment of conditions where it is beneficial to halt bone loss and kill cancer cells, particularly in metastases to and primary tumours in the bone and surrounding tissues. Consequently the present invention provides compounds comprising a bisphosphonate moiety linked to a phytochemical, pharmaceutical compositions thereof and methods of treatment of bone diseases and/or proliferative disorders.", "claim": "A compound of formula Q-T-L wherein Q is a bisphosphonate moiety, T is linker and L is an anti-osteolytic or osteoinductive phytochemical."}
+{"text": "The present invention provides compounds suitable for use in the treatment of conditions where it is beneficial to halt bone loss and kill cancer cells, particularly in metastases to and primary tumours in the bone and surrounding tissues. Consequently the present invention provides compounds comprising a bisphosphonate moiety linked to a phytochemical, pharmaceutical compositions thereof and methods of treatment of bone diseases and/or proliferative disorders.", "claim": "A compound of formula Q-T-L wherein Q is a bisphosphonate moiety, T is linker and L is an anti-osteolytic or osteoinductive phytochemical."}
+{"text": "The present invention provides compounds suitable for use in the treatment of conditions where it is beneficial to halt bone loss and kill cancer cells, particularly in metastases to and primary tumours in the bone and surrounding tissues. Consequently the present invention provides compounds comprising a bisphosphonate moiety linked to a phytochemical, pharmaceutical compositions thereof and methods of treatment of bone diseases and/or proliferative disorders.", "claim": "A compound of formula Q-T-L wherein Q is a bisphosphonate moiety, T is linker and L is an anti-osteolytic or osteoinductive phytochemical."}
+{"text": "The present invention provides compounds suitable for use in the treatment of conditions where it is beneficial to halt bone loss and kill cancer cells, particularly in metastases to and primary tumours in the bone and surrounding tissues. Consequently the present invention provides compounds comprising a bisphosphonate moiety linked to a phytochemical, pharmaceutical compositions thereof and methods of treatment of bone diseases and/or proliferative disorders.", "claim": "A compound of formula Q-T-L wherein Q is a bisphosphonate moiety, T is linker and L is an anti-osteolytic or osteoinductive phytochemical."}
+{"text": "The present invention provides compounds suitable for use in the treatment of conditions where it is beneficial to halt bone loss and kill cancer cells, particularly in metastases to and primary tumours in the bone and surrounding tissues. Consequently the present invention provides compounds comprising a bisphosphonate moiety linked to a phytochemical, pharmaceutical compositions thereof and methods of treatment of bone diseases and/or proliferative disorders.", "claim": "A compound of formula Q-T-L wherein Q is a bisphosphonate moiety, T is linker and L is an anti-osteolytic or osteoinductive phytochemical."}
+{"text": "The present invention provides compounds suitable for use in the treatment of conditions where it is beneficial to halt bone loss and kill cancer cells, particularly in metastases to and primary tumours in the bone and surrounding tissues. Consequently the present invention provides compounds comprising a bisphosphonate moiety linked to a phytochemical, pharmaceutical compositions thereof and methods of treatment of bone diseases and/or proliferative disorders.", "claim": "A compound of formula Q-T-L wherein Q is a bisphosphonate moiety, T is linker and L is an anti-osteolytic or osteoinductive phytochemical."}
diff --git a/src/gt4sd/training_pipelines/tests/molecules.smi b/src/gt4sd/training_pipelines/tests/molecules.smi
new file mode 100644
index 000000000..478175220
--- /dev/null
+++ b/src/gt4sd/training_pipelines/tests/molecules.smi
@@ -0,0 +1,64 @@
+CCO CHEMBL545
+C CHEMBL17564
+CO CHEMBL14688
+NCCS CHEMBL602
+NCCN CHEMBL816
+CN CHEMBL43280
+C=O CHEMBL1255
+CCN CHEMBL14449
+CSC CHEMBL15580
+CBr CHEMBL48339
+CI CHEMBL115849
+CF CHEMBL116838
+CC CHEMBL135626
+CNC=O CHEMBL9240
+CCCN CHEMBL14409
+CCCO CHEMBL14687
+O=CC#C CHEMBL722
+C=CC=O CHEMBL721
+CC#N CHEMBL45211
+CCCl CHEMBL46058
+NC#N CHEMBL56279
+CC=O CHEMBL76101
+SC#N CHEMBL84336
+FCF CHEMBL115186
+C#C CHEMBL116336
+CCl CHEMBL117545
+C=C CHEMBL117822
+COC CHEMBL119178
+CNC CHEMBL120433
+CCNCC CHEMBL1189
+CCC CHEMBL135416
+N#N CHEMBL142438
+CNO CHEMBL144761
+CNN CHEMBL160520
+C#N CHEMBL183419
+CC(C)O CHEMBL582
+CNC=O CHEMBL9081
+CCCCON CHEMBL6960
+CCNC=O CHEMBL9421
+CC(O)=O CHEMBL539
+CCCCO CHEMBL14245
+CCCCN CHEMBL13968
+COCOC CHEMBL15537
+CCC#N CHEMBL15871
+CCCCC CHEMBL16102
+CCOCC CHEMBL16264
+NC(N)=N CHEMBL821
+ClCCl CHEMBL45967
+NCC=C CHEMBL57286
+NC(N)=O CHEMBL985
+NCCO CHEMBL104943
+OCCF CHEMBL115586
+CC=C CHEMBL117213
+OC=O CHEMBL116736
+CC#C CHEMBL116902
+CCCC CHEMBL134702
+CCBr CHEMBL156378
+CNNC CHEMBL162921
+CC=O CHEMBL170365
+OCCS CHEMBL254951
+NC=O CHEMBL266160
+ON=C CHEMBL324784
+OCCO CHEMBL457299
+CON CHEMBL1213633
diff --git a/src/gt4sd/training_pipelines/tests/test_argument_parser.py b/src/gt4sd/training_pipelines/tests/test_argument_parser.py
new file mode 100644
index 000000000..320a73fd0
--- /dev/null
+++ b/src/gt4sd/training_pipelines/tests/test_argument_parser.py
@@ -0,0 +1,188 @@
+"""Argument parser unit tests."""
+
+from dataclasses import dataclass, field
+from typing import Union
+
+from gt4sd.cli.argument_parser import ArgumentParser
+
+
+@dataclass
+class TestArguments:
+
+    int_arg: int = field(default=0)
+
+    float_arg: float = field(default=0.0)
+
+    str_arg: str = field(default="test")
+
+    bool_arg: bool = field(default=True)
+
+    int_none_arg: Union[int, None] = field(default=None)
+
+    float_none_arg: Union[float, None] = field(default=None)
+
+    str_none_arg: Union[str, None] = field(default=None)
+
+    bool_none_arg: Union[bool, None] = field(default=None)
+
+
+def test_int_default():
+
+    parser = ArgumentParser((TestArguments))  # type: ignore
+
+    args = parser.parse_args_into_dataclasses([])
+
+    assert isinstance(args[0].int_arg, int)
+    assert args[0].int_arg == 0
+
+
+def test_float_default():
+
+    parser = ArgumentParser((TestArguments))  # type: ignore
+
+    args = parser.parse_args_into_dataclasses([])
+
+    assert isinstance(args[0].float_arg, float)
+    assert args[0].float_arg == 0.0
+
+
+def test_str_default():
+
+    parser = ArgumentParser((TestArguments))  # type: ignore
+
+    args = parser.parse_args_into_dataclasses([])
+
+    assert isinstance(args[0].str_arg, str)
+    assert args[0].str_arg == "test"
+
+
+def test_bool_default():
+
+    parser = ArgumentParser((TestArguments))  # type: ignore
+
+    args = parser.parse_args_into_dataclasses([])
+
+    assert isinstance(args[0].bool_arg, bool)
+    assert args[0].bool_arg is True
+
+
+def test_int_assigned():
+
+    parser = ArgumentParser((TestArguments))  # type: ignore
+
+    args = parser.parse_args_into_dataclasses(["--int_arg", "1"])
+
+    assert isinstance(args[0].int_arg, int)
+    assert args[0].int_arg == 1
+
+
+def test_float_assigned():
+
+    parser = ArgumentParser((TestArguments))  # type: ignore
+
+    args = parser.parse_args_into_dataclasses(["--float_arg", "1.0"])
+
+    assert isinstance(args[0].float_arg, float)
+    assert args[0].float_arg == 1.0
+
+
+def test_str_assigned():
+
+    parser = ArgumentParser((TestArguments))  # type: ignore
+
+    args = parser.parse_args_into_dataclasses(["--str_arg", "my_test"])
+
+    assert isinstance(args[0].str_arg, str)
+    assert args[0].str_arg == "my_test"
+
+
+def test_bool_assigned():
+
+    parser = ArgumentParser((TestArguments))  # type: ignore
+
+    args = parser.parse_args_into_dataclasses(["--bool_arg", "False"])
+
+    assert isinstance(args[0].bool_arg, bool)
+    assert args[0].bool_arg is False
+
+
+def test_bool_int_assigned():
+
+    parser = ArgumentParser((TestArguments))  # type: ignore
+
+    args = parser.parse_args_into_dataclasses(["--bool_arg", "0"])
+
+    assert isinstance(args[0].bool_arg, bool)
+    assert args[0].bool_arg is False
+
+
+def test_int_none():
+
+    parser = ArgumentParser((TestArguments))  # type: ignore
+
+    args = parser.parse_args_into_dataclasses([])
+
+    assert args[0].int_none_arg is None
+
+
+def test_float_none():
+
+    parser = ArgumentParser((TestArguments))  # type: ignore
+
+    args = parser.parse_args_into_dataclasses([])
+
+    assert args[0].float_none_arg is None
+
+
+def test_str_none():
+
+    parser = ArgumentParser((TestArguments))  # type: ignore
+
+    args = parser.parse_args_into_dataclasses([])
+
+    assert args[0].str_none_arg is None
+
+
+def test_bool_none():
+
+    parser = ArgumentParser((TestArguments))  # type: ignore
+
+    args = parser.parse_args_into_dataclasses([])
+
+    assert args[0].bool_none_arg is None
+
+
+def test_int_str_none():
+
+    parser = ArgumentParser((TestArguments))  # type: ignore
+
+    args = parser.parse_args_into_dataclasses(["--int_none_arg", ""])
+
+    assert args[0].int_none_arg is None
+
+
+def test_float_str_none():
+
+    parser = ArgumentParser((TestArguments))  # type: ignore
+
+    args = parser.parse_args_into_dataclasses(["--float_none_arg", ""])
+
+    assert args[0].float_none_arg is None
+
+
+def test_str_str_none():
+
+    parser = ArgumentParser((TestArguments))  # type: ignore
+
+    args = parser.parse_args_into_dataclasses(["--str_none_arg", ""])
+
+    assert args[0].str_none_arg is None
+
+
+def test_bool_str_none():
+
+    parser = ArgumentParser((TestArguments))  # type: ignore
+
+    args = parser.parse_args_into_dataclasses(["--bool_none_arg", ""])
+
+    assert args[0].bool_none_arg is None
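As the tests above exercise, this ArgumentParser derives CLI flags from dataclass fields and maps an empty string passed to a Union[..., None] field back to None. A minimal standalone sketch; the dataclass here is hypothetical and only for illustration:

from dataclasses import dataclass, field
from typing import Union

from gt4sd.cli.argument_parser import ArgumentParser


@dataclass
class DemoArguments:
    # Hypothetical arguments, not part of GT4SD.
    epochs: int = field(default=1)
    learning_rate: Union[float, None] = field(default=None)


parser = ArgumentParser(DemoArguments)  # type: ignore
(args,) = parser.parse_args_into_dataclasses(
    ["--epochs", "10", "--learning_rate", ""]
)
# args.epochs == 10 and args.learning_rate is None.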
diff --git a/src/gt4sd/training_pipelines/tests/test_training_language_modeling.py b/src/gt4sd/training_pipelines/tests/test_training_language_modeling.py
new file mode 100644
index 000000000..7194051d2
--- /dev/null
+++ b/src/gt4sd/training_pipelines/tests/test_training_language_modeling.py
@@ -0,0 +1,214 @@
+"""Language modeling trainer unit tests."""
+
+from typing import cast
+
+import pkg_resources
+from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+
+from gt4sd.training_pipelines import (
+    TRAINING_PIPELINE_MAPPING,
+    LanguageModelingTrainingPipeline,
+)
+from gt4sd.training_pipelines.pytorch_lightning.language_modeling.lm_datasets import (
+    CGMDataModule,
+    CLMDataModule,
+    MLMDataModule,
+    PLMDataModule,
+)
+from gt4sd.training_pipelines.pytorch_lightning.language_modeling.models import (
+    CGMModule,
+    CLMModule,
+    MLMModule,
+    PLMModule,
+)
+
+template_config = {
+    "model_args": {
+        "tokenizer": "albert-base-v2",
+        "model_name_or_path": "albert-base-v2",
+        "model_config_name": "albert-base-v2",
+        "type": "mlm",
+        "lr": 2e-5,
+        "lr_decay": 0.5,
+        "cache_dir": "/tmp/dci_",
+    },
+    "dataset_args": {
+        "max_length": 512,
+        "mlm_probability": 0.15,
+        "plm_probability": 0.1666,
+        "max_span_length": 5,
+        "batch_size": 8,
+        "train_file": "ade_corpus_v2",
+        "validation_file": "ade_corpus_v2",
+        "cache_dir": "/tmp/dci_",
+    },
+    "trainer_args": {
+        "default_root_dir": "here",
+        "val_check_interval": 5000,
+        "max_epochs": 1,
+        "accumulate_grad_batches": 1,
+        "limit_val_batches": 500,
+        "log_every_n_steps": 500,
+        "monitor": "val_loss",
+        "save_top_k": 2,
+        "mode": "min",
+        "every_n_train_steps": 50000,
+    },
+}
+
+
+def test_get_data_and_model_modules_mlm():
+
+    pipeline = TRAINING_PIPELINE_MAPPING.get("language-modeling-trainer")
+
+    assert pipeline is not None
+
+    test_pipeline = cast(LanguageModelingTrainingPipeline, pipeline())
+
+    config = template_config.copy()
+
+    file_path = pkg_resources.resource_filename(
+        "gt4sd",
+        "training_pipelines/tests/lm_example.jsonl",
+    )
+
+    config["dataset_args"]["train_file"] = file_path
+    config["dataset_args"]["validation_file"] = file_path
+    config["model_args"]["type"] = "mlm"
+
+    data_module, model_module = test_pipeline.get_data_and_model_modules(
+        config["model_args"], config["dataset_args"]  # type: ignore
+    )
+
+    assert isinstance(model_module, MLMModule)
+    assert isinstance(data_module, MLMDataModule)
+
+    check_model_config(model_module, config["model_args"])
+    check_data_config(data_module, config["dataset_args"])
+
+
+def test_get_data_and_model_modules_clm():
+
+    pipeline = TRAINING_PIPELINE_MAPPING.get("language-modeling-trainer")
+
+    assert pipeline is not None
+
+    test_pipeline = cast(LanguageModelingTrainingPipeline, pipeline())
+
+    config = template_config.copy()
+
+    file_path = pkg_resources.resource_filename(
+        "gt4sd",
+        "training_pipelines/tests/lm_example.jsonl",
+    )
+
+    config["dataset_args"]["train_file"] = file_path
+    config["dataset_args"]["validation_file"] = file_path
+    config["model_args"]["type"] = "clm"
+    config["model_args"]["model_name_or_path"] = "gpt2"
+
+    data_module, model_module = test_pipeline.get_data_and_model_modules(
+        config["model_args"], config["dataset_args"]  # type: ignore
+    )
+
+    assert isinstance(model_module, CLMModule)
+    assert isinstance(data_module, CLMDataModule)
+
+    check_model_config(model_module, config["model_args"])
+    check_data_config(data_module, config["dataset_args"])
+
+
+def test_get_data_and_model_modules_cgm():
+
+    pipeline = TRAINING_PIPELINE_MAPPING.get("language-modeling-trainer")
+
+    assert pipeline is not None
+
+    test_pipeline = cast(LanguageModelingTrainingPipeline, pipeline())
+
+    config = template_config.copy()
+
+    file_path = pkg_resources.resource_filename(
+        "gt4sd",
+        "training_pipelines/tests/lm_example.jsonl",
+    )
+
+    config["dataset_args"]["train_file"] = file_path
+    config["dataset_args"]["validation_file"] = file_path
+    config["model_args"]["type"] = "cgm"
+    config["model_args"]["model_name_or_path"] = "t5-base"
+
+    data_module, model_module = test_pipeline.get_data_and_model_modules(
+        config["model_args"], config["dataset_args"]  # type: ignore
+    )
+
+    assert isinstance(model_module, CGMModule)
+    assert isinstance(data_module, CGMDataModule)
+
+    check_model_config(model_module, config["model_args"])
+    check_data_config(data_module, config["dataset_args"])
+
+
+def test_get_data_and_model_modules_plm():
+
+    pipeline = TRAINING_PIPELINE_MAPPING.get("language-modeling-trainer")
+
+    assert pipeline is not None
+
+    test_pipeline = cast(LanguageModelingTrainingPipeline, pipeline())
+
+    config = template_config.copy()
+
+    file_path = pkg_resources.resource_filename(
+        "gt4sd",
+        "training_pipelines/tests/lm_example.jsonl",
+    )
+
+    config["dataset_args"]["train_file"] = file_path
+    config["dataset_args"]["validation_file"] = file_path
+    config["model_args"]["type"] = "plm"
+    config["model_args"]["model_name_or_path"] = "xlnet-base-cased"
+
+    data_module, model_module = test_pipeline.get_data_and_model_modules(
+        config["model_args"], config["dataset_args"]  # type: ignore
+    )
+
+    assert isinstance(model_module, PLMModule)
+    assert isinstance(data_module, PLMDataModule)
+
+    check_model_config(model_module, config["model_args"])
+    check_data_config(data_module, config["dataset_args"])
+
+
+def test_add_callbacks():
+
+    pipeline = TRAINING_PIPELINE_MAPPING.get("language-modeling-trainer")
+
+    assert pipeline is not None
+
+    callbacks_input = {
+        "model_checkpoint_callback": {
+            "monitor": "val_loss",
+            "save_top_k": 2,
+            "mode": "min",
+            "every_n_train_steps": 50000,
+        }
+    }
+
+    callbacks = pipeline().add_callbacks(callbacks_input)  # type: ignore
+
+    assert len(callbacks) == 1
+    assert isinstance(callbacks[0], ModelCheckpoint)
+
+
+def check_model_config(module, config):
+    for entry in module.model_args:
+        assert entry in config
+        assert config[entry] == module.model_args[entry]
+
+
+def check_data_config(module, config):
+    for entry in module.dataset_args:
+        assert entry in config
+        assert config[entry] == module.dataset_args[entry]
diff --git a/src/gt4sd/training_pipelines/tests/test_training_paccmann_vae.py b/src/gt4sd/training_pipelines/tests/test_training_paccmann_vae.py
new file mode 100644
index 000000000..7f3e7764b
--- /dev/null
+++ b/src/gt4sd/training_pipelines/tests/test_training_paccmann_vae.py
@@ -0,0 +1,70 @@
+"""PaccMann VAE trainer unit tests."""
+
+from typing import Any, Dict, cast
+
+import pkg_resources
+
+from gt4sd.training_pipelines import (
+    TRAINING_PIPELINE_MAPPING,
+    PaccMannVAETrainingPipeline,
+)
+
+template_config = {
+    "model_args": {
+        "n_layers": 1,
"bidirectional": False, + "rnn_cell_size": 64, + "latent_dim": 32, + "stack_width": 8, + "stack_depth": 8, + "decoder_search": "sampling", + "dropout": 0.2, + "generate_len": 50, + "kl_growth": 0.003, + "input_keep": 0.85, + "test_input_keep": 1.0, + "temperature": 0.8, + "embedding": "one_hot", + "batch_mode": "packed", + "vocab_size": 380, + "pad_index": 0, + "embedding_size": 380, + }, + "dataset_args": { + "add_start_stop_token": True, + "selfies": True, + "num_workers": 1, + "pin_memory": False, + }, + "training_args": { + "epochs": 1, + "batch_size": 4, + "learning_rate": 0.0005, + "optimizer": "adam", + "log_interval": 2, + "save_interval": 2, + "eval_interval": 2, + "model_path": "/tmp/paccmann_vae", + "training_name": "paccmann-vae-test", + }, +} + + +def test_train(): + + pipeline = TRAINING_PIPELINE_MAPPING.get("paccmann-vae-trainer") + + assert pipeline is not None + + test_pipeline = cast(PaccMannVAETrainingPipeline, pipeline()) + + config: Dict[str, Any] = template_config.copy() + + file_path = pkg_resources.resource_filename( + "gt4sd", + "training_pipelines/tests/molecules.smi", + ) + + config["dataset_args"]["train_smiles_filepath"] = file_path + config["dataset_args"]["test_smiles_filepath"] = file_path + test_pipeline.train(**config) diff --git a/src/gt4sd/training_pipelines/tests/test_training_pipelines.py b/src/gt4sd/training_pipelines/tests/test_training_pipelines.py new file mode 100644 index 000000000..3baaa14fe --- /dev/null +++ b/src/gt4sd/training_pipelines/tests/test_training_pipelines.py @@ -0,0 +1,42 @@ +"""Exceptions tests.""" + +from gt4sd.training_pipelines import ( + TRAINING_PIPELINE_ARGUMENTS_MAPPING, + TRAINING_PIPELINE_NAME_METADATA_MAPPING, + training_pipeline_name_to_metadata, +) + + +def test_metadata_retrieval_for_registered_pipelines_from_json(): + for name, filename in TRAINING_PIPELINE_NAME_METADATA_MAPPING.items(): + pipeline_metadata = training_pipeline_name_to_metadata(name) + assert pipeline_metadata["training_pipeline"] == name + assert "description" in pipeline_metadata + assert "parameters" in pipeline_metadata + assert "description" not in pipeline_metadata["parameters"] + + +def test_metadata_retrieval_for_registered_pipelines_from_dataclass(): + for name, filename in TRAINING_PIPELINE_ARGUMENTS_MAPPING.items(): + pipeline_metadata = training_pipeline_name_to_metadata(name) + assert pipeline_metadata["training_pipeline"] == name + assert "description" in pipeline_metadata + assert "parameters" in pipeline_metadata + assert "description" not in pipeline_metadata["parameters"] + + for parameter in pipeline_metadata["parameters"]: + assert "description" in pipeline_metadata["parameters"][parameter] + assert "type" in pipeline_metadata["parameters"][parameter] + + assert len(pipeline_metadata["parameters"][parameter]) <= 3 + + if len(pipeline_metadata["parameters"][parameter]) == 3: + assert "default" in pipeline_metadata["parameters"][parameter] + + +def test_metadata_retrieval_for_unregistered_pipelines(): + name = "this pipeline does not exists and can't be registered" + pipeline_metadata = training_pipeline_name_to_metadata(name) + assert pipeline_metadata["training_pipeline"] == name + assert pipeline_metadata["description"] == "A training pipeline." + assert pipeline_metadata["parameters"] == {}