feat: integrate diffusion models in gt4sd (#128)
* wip: skeleton diffusion

* wip: add implementation, core

* wip: add diffusion models and noise scheduler support

* wip: add tests

* fix: fix imports

* wip: black files

* feat: add modality, minor fix

* fix: minor fix tests

* fix tests

* add tests

* add training_pipeline for diffusion

* minor change

* remove double reqs

* set ubuntu version ci

* ci: relax latest, fix examples

* add diffusion-trainer, test diffusion training_pipeline

* skip high-memory tests

* minor change

* fix: segfaults on ulinux for rdkit/torchvision conflicts

* fix tests, add demo

* fix model_types

* improve tests

* feat: add notebook for text2image, improve testing

* fix tests, skip auth_token, skip slow_sampling

* separate tests unconditional and text-conditional

* fix reqs

* add tests for diffusion training_pipeline

* add stable-diffusion in notebook, use conditional reqs

* fix conda.yml path

* minor changes, explicit name training_pipeline for image generation

* fix reqs

* change training_pipeline name for consistency

* use common interface for sampling, add target, improve notebook, fix issue with hashable

* feat: simplifying sampling logic.

Signed-off-by: Matteo Manica <[email protected]>

* fix: pass prompt in configuration

* fix style, update notebook

* black files

* refactor sample

Signed-off-by: Matteo Manica <[email protected]>
Co-authored-by: Jannis Born <[email protected]>
Co-authored-by: Matteo Manica <[email protected]>
3 people authored Sep 12, 2022
1 parent 10319c4 commit 94de023
Showing 20 changed files with 1,905 additions and 80 deletions.
16 changes: 0 additions & 16 deletions .github/conda_ci.yml

This file was deleted.

34 changes: 0 additions & 34 deletions .github/requirements_ci.txt

This file was deleted.

2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
@@ -24,7 +24,7 @@ jobs:
 - uses: conda-incubator/setup-miniconda@v2
   with:
     activate-environment: gt4sd
-    environment-file: .github/conda_ci.yml
+    environment-file: conda.yml
     auto-activate-base: false
     use-only-tar-bz2: true
 - name: Install gt4sd from source
2 changes: 1 addition & 1 deletion .gitignore
@@ -145,5 +145,5 @@ oracle
 # local tests
 _main_*

-# data
+# data folder
 data/
2 changes: 2 additions & 0 deletions README.md
@@ -210,6 +210,7 @@ The trainer currently supports the following training pipelines:
 - `moses-vae-trainer`: Moses VAE models.
 - `torchdrug-gcpn-trainer`: TorchDrug Graph Convolutional Policy Network model.
 - `torchdrug-graphaf-trainer`: TorchDrug autoregressive GraphAF model.
+- `diffusion-trainer`: Diffusers model.
 - `gflownet-trainer`: GFlowNet model.

```console
@@ -337,6 +338,7 @@ Beyond implementing various generative modeling inference and training pipelines
 - [TAPE](https://github.com/songlab-cal/tape): encoder modules compatible with the protein language models.
 - [PaccMann](https://github.com/PaccMann/): inference pipelines for all algorithms of the PaccMann family as well as training pipelines for the generative VAEs.
 - [transformers](https://huggingface.co/transformers): training and inference pipelines for generative models from [HuggingFace Models](https://huggingface.co/models)
+- [diffusers](https://github.com/huggingface/diffusers): training and inference pipelines for generative models from [Diffusers Models](https://github.com/huggingface/diffusers)
 - [GFlowNets](https://github.com/recursionpharma/gflownet): training and inference pipeline for [Generative Flow Networks](https://yoshuabengio.org/2022/03/05/generative-flow-networks/)

 ## References
6 changes: 4 additions & 2 deletions examples/granular/ae_mlp_example/example_train_ae_and_mlp.py
@@ -1,6 +1,8 @@
-from gt4sd.frameworks.granular.train.core import parse_arguments_from_config, train_granular
+from gt4sd.frameworks.granular.train.core import (
+    parse_arguments_from_config,
+    train_granular,
+)

 args = parse_arguments_from_config("config_ae.ini")

 train_granular(vars(args))

489 changes: 489 additions & 0 deletions notebooks/diffusion-demo.ipynb

Large diffs are not rendered by default.

24 changes: 16 additions & 8 deletions requirements.txt
@@ -1,13 +1,14 @@
 # pypi requirements
+accelerate>=0.12
 datasets>=1.11.0
+diffusers>=0.2.4
 joblib>=1.1.0
 keras==2.3.1
-torchmetrics<0.7
 keybert==0.2.0
 minio==7.0.1
 modlamp>=4.0.0
-scikit-learn<0.24.0
 numpy>=1.16.5
+protobuf<3.20
 pytorch_lightning<=1.5.0
 pydantic>=1.7.3,<=1.9.2
 PyTDC>=0.3.6
@@ -16,19 +17,26 @@ rdkit-pypi>=2020.9.5.2,<=2021.9.4
 regex>=2.5.91
 reinvent-chemistry==0.0.38
 sacremoses>=0.0.41
+scikit-learn<0.24.0
 scikit-optimize>=0.8.1
 sentencepiece>=0.1.95
 sympy>=1.10.1
 tables>=3.7.0
 tape-proteins>=0.4
-protobuf<3.20
 tensorboard!=2.5.0,>=2.2.0
 tensorflow==2.1.0
-torch>=1.0
-torch-cluster>=1.6.0
-torch-geometric>=2.0.4
-torch-sparse>=0.6.14
+# cpu for ci tests
+torch>=1.0+cpu; sys_platform == "linux"
+torch>=1.0; sys_platform != "linux"
+torch-cluster>=1.6.0+cpu; sys_platform == 'linux'
+torch-cluster>=1.6.0; sys_platform != 'linux'
+torch-geometric>=2.0.4+cpu; sys_platform == 'linux'
+torch-geometric>=2.0.4; sys_platform != 'linux'
+torch-sparse>=0.6.14+cpu; sys_platform == 'linux'
+torch-sparse>=0.6.14; sys_platform != 'linux'
 torchdrug>=0.1.2
+torchmetrics<0.7
+torchvision>=0.12.0
 transformers>=4.2.1
 typing_extensions>=3.7.4.3
-wheel>=0.26
+wheel>=0.26
17 changes: 16 additions & 1 deletion setup.cfg
@@ -188,6 +188,21 @@ ignore_missing_imports = True
 [mypy-modlamp.*]
 ignore_missing_imports = True

+[mypy-diffusers.*]
+ignore_missing_imports = True
+
+[mypy-accelerate.*]
+ignore_missing_imports = True
+
+[mypy-tqdm.*]
+ignore_missing_imports = True
+
+[mypy-torchvision.*]
+ignore_missing_imports = True
+
+[mypy-PIL.*]
+ignore_missing_imports = True
+
 [mypy-scipy.*]
 ignore_missing_imports = True

@@ -204,4 +219,4 @@ ignore_missing_imports = True
 ignore_missing_imports = True

 [mypy-sympy.*]
-ignore_missing_imports = True
+ignore_missing_imports = True
8 changes: 8 additions & 0 deletions src/gt4sd/algorithms/__init__.py
@@ -50,6 +50,14 @@
     CatalystGenerator,
 )
 from .controlled_sampling.paccmann_gp.core import PaccMannGPGenerator  # noqa: F401
+from .generation.diffusion.core import (  # noqa: F401
+    DDIMGenerator,
+    DDPMGenerator,
+    LDMGenerator,
+    LDMTextToImageGenerator,
+    ScoreSdeGenerator,
+    StableDiffusionGenerator,
+)
 from .generation.hugging_face.core import (  # noqa: F401
     HuggingFaceCTRLGenerator,
     HuggingFaceGPT2Generator,
40 changes: 24 additions & 16 deletions src/gt4sd/algorithms/core.py
@@ -25,6 +25,7 @@

 from __future__ import annotations

+import collections
 import logging
 import os
 import shutil
@@ -224,23 +225,30 @@ def raise_if_none_sampled(items: set, detail: str):
         try:
             while True:
                 generated_items = self.generate()  # type:ignore
-                for item in generated_items:
-                    if item in item_set:
-                        continue
-                    else:
-                        try:
-                            valid_item = self.configuration.validate_item(item)
+                for index, item in enumerate(generated_items):
+                    try:
+                        valid_item = self.configuration.validate_item(item)
+                        # check if sample is hashable
+                        if not isinstance(item, collections.Hashable):
                             yield valid_item
-                            item_set.add(item)
-                            if len(item_set) == number_of_items:
-                                signal.alarm(0)
-                                return
-                        except InvalidItem as error:
-                            logger.debug(
-                                f"item {item} could not be validated, "
-                                f"raising {error.title}: {error.detail}"
-                            )
-                            continue
+                            item_set.add(str(index))
+                        else:
+                            # validation for samples represented as strings
+                            if item in item_set:
+                                continue
+                            else:
+                                yield valid_item
+                                item_set.add(item)  # type:ignore
+                        if len(item_set) == number_of_items:
+                            signal.alarm(0)
+                            return
+                    except InvalidItem as error:
+                        logger.debug(
+                            f"item {item} could not be validated, "
+                            f"raising {error.title}: {error.detail}"
+                        )
+                        continue
+
                 # make sure we don't keep sampling more than a given number of times,
                 # in case no new items are generated.
                 if len(item_set) == item_set_length:
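
The sampling loop in this hunk deduplicates hashable items through a set and treats unhashable ones (for example images) as always new. The following is a minimal standalone sketch of that pattern, not the gt4sd implementation: `unique_valid_items`, its `validate` argument, and the use of `ValueError` in place of gt4sd's `InvalidItem` are assumptions made for illustration. It imports `Hashable` from `collections.abc`, since the `collections.Hashable` alias used in the diff was removed in Python 3.10, and it counts yielded items directly instead of storing `str(index)` keys.

```python
from collections.abc import Hashable
from typing import Any, Callable, Iterable, Iterator


def unique_valid_items(
    batches: Iterable[Iterable[Any]],
    validate: Callable[[Any], Any],
    number_of_items: int,
) -> Iterator[Any]:
    """Yield up to number_of_items validated items, skipping hashable duplicates."""
    seen: set = set()
    yielded = 0
    for batch in batches:
        for item in batch:
            try:
                valid_item = validate(item)
            except ValueError:
                # stand-in for gt4sd's InvalidItem: skip items that fail validation
                continue
            if isinstance(item, Hashable):
                # hashable items (e.g. SMILES strings) are deduplicated via the set
                if item in seen:
                    continue
                seen.add(item)
            # unhashable items cannot be stored in a set, so each one counts as new
            yield valid_item
            yielded += 1
            if yielded == number_of_items:
                return
```

With `batches = [["a", "b", "a"], [["x"], ["x"], "b", "c"]]` and an identity `validate`, the repeated strings are skipped while both unhashable lists are yielded.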
45 changes: 45 additions & 0 deletions src/gt4sd/algorithms/generation/diffusion/__init__.py
@@ -0,0 +1,45 @@
+#
+# MIT License
+#
+# Copyright (c) 2022 GT4SD team
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+"""DiffusersGenerationAlgorithm initialization."""
+
+from .core import (
+    DDIMGenerator,
+    DDPMGenerator,
+    DiffusersGenerationAlgorithm,
+    LDMGenerator,
+    LDMTextToImageGenerator,
+    ScoreSdeGenerator,
+    StableDiffusionGenerator,
+)
+
+__all__ = [
+    "DiffusersGenerationAlgorithm",
+    "DDPMGenerator",
+    "DDIMGenerator",
+    "DiffusionGenerator",
+    "StableDiffusionGenerator",
+    "ScoreSdeGenerator",
+    "LDMGenerator",
+    "LDMTextToImageGenerator",
+]
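
The `__all__` list in this file names `"DiffusionGenerator"`, which is not among the names imported from `.core` here. If the package does not define that name some other way, a wildcard import of the package would raise `AttributeError`. The sketch below shows one way to detect such mismatches; `missing_exports` is a hypothetical helper and `demo` is a stand-in module built for illustration, not gt4sd code.

```python
import types
from typing import List


def missing_exports(module: types.ModuleType) -> List[str]:
    """Return names listed in module.__all__ that the module does not define."""
    exported = getattr(module, "__all__", [])
    return [name for name in exported if not hasattr(module, name)]


# stand-in module mimicking the situation in the diff (not gt4sd code)
demo = types.ModuleType("demo")
demo.DDPMGenerator = object
demo.__all__ = ["DDPMGenerator", "DiffusionGenerator"]

print(missing_exports(demo))  # prints ['DiffusionGenerator']
```

A check like this could run in a unit test over each package `__init__` so that stale entries in `__all__` are caught before release.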