Skip to content

Commit

Permalink
chore: ensuring order of tensorflow imports. (#218)
Browse files Browse the repository at this point in the history
* chore: ensuring order of tensorflow imports.

Upperbounding tensorflow-io-gcs-filesystem and importing before lightning to avoid segfaults.

Signed-off-by: Matteo Manica <[email protected]>

* style: fixing mypy configuration and import order.

Signed-off-by: Matteo Manica <[email protected]>

* Fix Segmentation faults for GPU installation (#219)

* fix: fix GPU segfault by order-forcing

* ci: temporarily removing isort check

* chore: handling isort removal.

Signed-off-by: Matteo Manica <[email protected]>

---------

Signed-off-by: Matteo Manica <[email protected]>
Co-authored-by: Jannis Born <[email protected]>
  • Loading branch information
drugilsberg and jannisborn authored May 5, 2023
1 parent 225879c commit 63970d7
Show file tree
Hide file tree
Showing 19 changed files with 96 additions and 25 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ jobs:
run: |
conda activate gt4sd
python -m black src/gt4sd --check --diff --color
- name: Check isort
run: |
conda activate gt4sd
python -m isort src/gt4sd --check-only
# - name: Check isort
# run: |
# conda activate gt4sd
# python -m isort src/gt4sd --check-only
- name: Check flake8
run: |
conda activate gt4sd
Expand Down
1 change: 0 additions & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ If you would like to contribute to the package, we recommend the following devel
```sh
# blacking and sorting imports (this might change your files)
python -m black src/gt4sd
python -m isort src/gt4sd
# checking flake8 and mypy
python -m flake8 --disable-noqa --per-file-ignores="__init__.py:F401" src/gt4sd
python -m mypy src/gt4sd
Expand Down
2 changes: 1 addition & 1 deletion dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ docutils==0.17.1
flake8==3.8.4
flask==1.1.2
flask_login==0.5.0
isort==5.7.0
# isort==5.7.0
licenseheaders==0.8.8
mypy==0.950
myst-parser==0.13.3
Expand Down
16 changes: 8 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ line-length = 88
skip-string-normalization = false
target-version = ['py37']

[tool.isort]
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
line_length = 88
force_to_top = ["sentencepiece", "rdkit", "scikit-learn"]
# [tool.isort]
# multi_line_output = 3
# include_trailing_comma = true
# force_grid_wrap = 0
# use_parentheses = true
# ensure_newline_before_comments = true
# line_length = 88
# force_to_top = ["sentencepiece", "torch", "tensorflow", "rdkit", "scikit-learn"]
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ tape-proteins>=0.4
tensorboard!=2.5.0,>=2.2.0,<2.11.0
tensorboard-data-server<=0.6.1
tensorflow>=2.1.0,<2.11.0
tensorflow-io-gcs-filesystem<0.32.0
torchdrug>=0.2.0
torchmetrics>=0.7.0
transformers>=4.22.0,<=4.24.0
Expand Down
9 changes: 8 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ install_requires =
tables
tape-proteins
tensorboard
tensorflow
torch
torchdrug
torchmetrics
Expand Down Expand Up @@ -273,4 +274,10 @@ ignore_missing_imports = True
ignore_missing_imports = True

[mypy-gt4sd_trainer.hf_pl.*]
ignore_missing_imports = True
ignore_missing_imports = True

[mypy-tensorflow.*]
ignore_missing_imports = True

[mypy-ruamel.*]
ignore_missing_imports = True
6 changes: 5 additions & 1 deletion src/gt4sd/frameworks/gflownet/arg_parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,18 @@
from typing import Any, Dict, Optional

import sentencepiece as _sentencepiece
import torch as _torch
import tensorflow as _tensorflow
import numpy as np
from pytorch_lightning import Trainer

from ..ml.models import MODEL_FACTORY
from .utils import convert_string_to_class

# sentencepiece has to be loaded before lightning to avoid segfaults
# imports that have to be loaded before lightning to avoid segfaults
_sentencepiece
_tensorflow
_torch


def parse_arguments_from_config(conf_file: Optional[str] = None) -> argparse.Namespace:
Expand Down
6 changes: 5 additions & 1 deletion src/gt4sd/frameworks/gflownet/dataloader/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from typing import Any, Dict, Optional

import sentencepiece as _sentencepiece
import torch as _torch
import tensorflow as _tensorflow
import numpy as np
import pytorch_lightning as pl
import torch.nn as nn
Expand All @@ -38,8 +40,10 @@
from ..ml.models import MODEL_FACTORY
from .sampler import SamplingIterator

# sentencepiece has to be loaded before lightning to avoid segfaults
# imports that have to be loaded before lightning to avoid segfaults
_sentencepiece
_tensorflow
_torch

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
Expand Down
6 changes: 5 additions & 1 deletion src/gt4sd/frameworks/gflownet/ml/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from typing import Any, Dict, List, Optional, Tuple

import sentencepiece as _sentencepiece
import torch as _torch
import tensorflow as _tensorflow
import pandas as pd
import pytorch_lightning as pl
import torch
Expand All @@ -39,8 +41,10 @@
from ..dataloader.dataset import GFlowNetDataset, GFlowNetTask
from ..envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext

# sentencepiece has to be loaded before lightning to avoid segfaults
# imports that have to be loaded before lightning to avoid segfaults
_sentencepiece
_tensorflow
_torch

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
Expand Down
8 changes: 8 additions & 0 deletions src/gt4sd/frameworks/gflownet/tests/test_gfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
#
from argparse import Namespace

import sentencepiece as _sentencepiece
import torch as _torch
import tensorflow as _tensorflow
import numpy as np
import pytest
import pytorch_lightning as pl
Expand All @@ -35,6 +38,11 @@
from gt4sd.frameworks.gflownet.ml.module import GFlowNetModule
from gt4sd.frameworks.gflownet.tests.qm9 import QM9Dataset, QM9GapTask

# imports that have to be loaded before lightning to avoid segfaults
_sentencepiece
_tensorflow
_torch

configuration = {
"bootstrap_own_reward": False,
"learning_rate": 1e-4,
Expand Down
6 changes: 5 additions & 1 deletion src/gt4sd/frameworks/gflownet/train/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from typing import Any, Dict

import sentencepiece as _sentencepiece
import torch as _torch
import tensorflow as _tensorflow
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
Expand All @@ -46,8 +48,10 @@

# from ..train import build_task

# sentencepiece has to be loaded before lightning to avoid segfaults
# imports that have to be loaded before lightning to avoid segfaults
_sentencepiece
_tensorflow
_torch

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
Expand Down
6 changes: 5 additions & 1 deletion src/gt4sd/frameworks/granular/arg_parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,17 @@
from typing import Any, Dict, Optional

import sentencepiece as _sentencepiece
import torch as _torch
import tensorflow as _tensorflow
from pytorch_lightning import Trainer

from ..ml.models import ARCHITECTURE_FACTORY
from .utils import convert_string_to_class

# sentencepiece has to be loaded before lightning to avoid segfaults
# imports that have to be loaded before lightning to avoid segfaults
_sentencepiece
_tensorflow
_torch


def parse_arguments_from_config(conf_file: Optional[str] = None) -> argparse.Namespace:
Expand Down
6 changes: 5 additions & 1 deletion src/gt4sd/frameworks/granular/dataloader/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from typing import Callable, List, Optional, cast

import sentencepiece as _sentencepiece
import torch as _torch
import tensorflow as _tensorflow
import pandas as pd
import pytorch_lightning as pl
import torch
Expand All @@ -35,8 +37,10 @@
from .dataset import CombinedGranularDataset, GranularDataset
from .sampler import StratifiedSampler

# sentencepiece has to be loaded before lightning to avoid segfaults
# imports that have to be loaded before lightning to avoid segfaults
_sentencepiece
_tensorflow
_torch

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
Expand Down
6 changes: 5 additions & 1 deletion src/gt4sd/frameworks/granular/ml/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,19 @@
from typing import Any, Callable, Dict, List, Tuple, cast

import sentencepiece as _sentencepiece
import torch as _torch
import tensorflow as _tensorflow
import pandas as pd
import pytorch_lightning as pl
import torch

from .models import GranularBaseModel, GranularEncoderDecoderModel
from .models.model_builder import building_models, define_latent_models_input_size

# sentencepiece has to be loaded before lightning to avoid segfaults
# imports that have to be loaded before lightning to avoid segfaults
_sentencepiece
_tensorflow
_torch


class GranularModule(pl.LightningModule):
Expand Down
6 changes: 5 additions & 1 deletion src/gt4sd/frameworks/granular/train/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from typing import Any, Dict

import sentencepiece as _sentencepiece
import torch as _torch
import tensorflow as _tensorflow
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
Expand All @@ -38,8 +40,10 @@
from ..ml.models import AUTOENCODER_ARCHITECTURES
from ..ml.module import GranularModule

# sentencepiece has to be loaded before lightning to avoid segfaults
# imports that have to be loaded before lightning to avoid segfaults
_sentencepiece
_tensorflow
_torch

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
Expand Down
8 changes: 8 additions & 0 deletions src/gt4sd/training_pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
import logging
from typing import Any, Dict

import sentencepiece as _sentencepiece
import torch as _torch
import tensorflow as _tensorflow
from gt4sd_trainer.hf_pl.core import (
LanguageModelingDataArguments,
LanguageModelingModelArguments,
Expand Down Expand Up @@ -126,6 +129,11 @@
TorchDrugGraphAFTrainingPipeline,
)

# imports that have to be loaded before lightning to avoid segfaults
_sentencepiece
_tensorflow
_torch

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from typing import Any, Dict, Optional, Tuple, Union

import sentencepiece as _sentencepiece
import torch as _torch
import tensorflow as _tensorflow
from gt4sd_trainer.hf_pl.pytorch_lightning_trainer import (
PytorchLightningTrainingArguments,
PyTorchLightningTrainingPipeline,
Expand All @@ -49,8 +51,10 @@
from ....frameworks.gflownet.ml.module import GFlowNetModule
from ...core import TrainingPipelineArguments

# sentencepiece has to be loaded before lightning to avoid segfaults
# imports that have to be loaded before lightning to avoid segfaults
_sentencepiece
_tensorflow
_torch

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from typing import Any, Dict, Optional, Tuple

import sentencepiece as _sentencepiece
import torch as _torch
import tensorflow as _tensorflow
from gt4sd_trainer.hf_pl.pytorch_lightning_trainer import (
PytorchLightningTrainingArguments,
PyTorchLightningTrainingPipeline,
Expand All @@ -42,8 +44,10 @@
from ....frameworks.granular.ml.module import GranularModule
from ...core import TrainingPipelineArguments

# sentencepiece has to be loaded before lightning to avoid segfaults
# imports that have to be loaded before lightning to avoid segfaults
_sentencepiece
_tensorflow
_torch

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Union

import sentencepiece as _sentencepiece
import torch as _torch
import tensorflow as _tensorflow
import importlib_resources
from gt4sd_molformer.finetune.finetune_pubchem_light import (
LightningModule as RegressionLightningModule,
Expand Down Expand Up @@ -58,6 +61,11 @@

from ...core import TrainingPipelineArguments

# imports that have to be loaded before lightning to avoid segfaults
_sentencepiece
_tensorflow
_torch

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

Expand Down

0 comments on commit 63970d7

Please sign in to comment.