Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introducing trt_compile() and example how to use it #407

Open
wants to merge 6 commits into
base: bionemo1
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions bionemo/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from dataclasses import asdict
from typing import Any, Sequence, Union

from bionemo.utils.trt_compiler import trt_compile


__all__: Sequence[str] = (
"update_dataclass_config",
Expand Down
194 changes: 194 additions & 0 deletions bionemo/utils/import_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities and types for defining networks, these depend on PyTorch.
"""

from __future__ import annotations

from collections.abc import Callable
from importlib import import_module
from typing import Any


OPTIONAL_IMPORT_MSG_FMT = "{}"


def min_version(the_module: Any, min_version_str: str = "", *_args: Any) -> bool:
"""
Convert version strings into tuples of int and compare them.

Returns True if the module's version is greater or equal to the 'min_version'.
When min_version_str is not provided, it always returns True.
"""
if not min_version_str or not hasattr(the_module, "__version__"):
return True # always valid version

mod_version = tuple(int(x) for x in the_module.__version__.split(".")[:2])
required = tuple(int(x) for x in min_version_str.split(".")[:2])
return mod_version >= required


def exact_version(the_module: Any, version_str: str = "", *_args: Any) -> bool:
"""
Returns True if the module's __version__ matches version_str
"""
if not hasattr(the_module, "__version__"):
return False
return bool(the_module.__version__ == version_str)


class InvalidPyTorchVersionError(Exception):
"""
Raised when called function or method requires a more recent
PyTorch version than that installed.
"""

def __init__(self, required_version, name):
message = f"{name} requires PyTorch version {required_version} or later"
super().__init__(message)


class OptionalImportError(ImportError):
"""
Could not import APIs from an optional dependency.
"""


def optional_import(
module: str,
version: str = "",
version_checker: Callable[..., bool] = min_version,
name: str = "",
descriptor: str = OPTIONAL_IMPORT_MSG_FMT,
version_args: Any = None,
allow_namespace_pkg: bool = False,
as_type: str = "default",
) -> tuple[Any, bool]:
"""
Imports an optional module specified by `module` string.
Any importing related exceptions will be stored, and exceptions raise lazily
when attempting to use the failed-to-import module.

Args:
module: name of the module to be imported.
version: version string used by the version_checker.
version_checker: a callable to check the module version, Defaults to monai.utils.min_version.
name: a non-module attribute (such as method/class) to import from the imported module.
descriptor: a format string for the final error message when using a not imported module.
version_args: additional parameters to the version checker.
allow_namespace_pkg: whether importing a namespace package is allowed. Defaults to False.
as_type: there are cases where the optionally imported object is used as
a base class, or a decorator, the exceptions should raise accordingly. The current supported values
are "default" (call once to raise), "decorator" (call the constructor and the second call to raise),
and anything else will return a lazy class that can be used as a base class (call the constructor to raise).

Returns:
The imported module and a boolean flag indicating whether the import is successful.

Examples::

>>> torch, flag = optional_import('torch', '1.1')
>>> print(torch, flag)
<module 'torch' from 'python/lib/python3.6/site-packages/torch/__init__.py'> True

>>> the_module, flag = optional_import('unknown_module')
>>> print(flag)
False
>>> the_module.method # trying to access a module which is not imported
OptionalImportError: import unknown_module (No module named 'unknown_module').

>>> torch, flag = optional_import('torch', '42', exact_version)
>>> torch.nn # trying to access a module for which there isn't a proper version imported
OptionalImportError: import torch (requires version '42' by 'exact_version').

>>> conv, flag = optional_import('torch.nn.functional', '1.0', name='conv1d')
>>> print(conv)
<built-in method conv1d of type object at 0x11a49eac0>

>>> conv, flag = optional_import('torch.nn.functional', '42', name='conv1d')
>>> conv() # trying to use a function from the not successfully imported module (due to unmatched version)
OptionalImportError: from torch.nn.functional import conv1d (requires version '42' by 'min_version').
"""

tb = None
exception_str = ""
if name:
actual_cmd = f"from {module} import {name}"
else:
actual_cmd = f"import {module}"
try:
pkg = __import__(module) # top level module
the_module = import_module(module)
if not allow_namespace_pkg:
is_namespace = getattr(the_module, "__file__", None) is None and hasattr(the_module, "__path__")
if is_namespace:
raise AssertionError
if name: # user specified to load class/function/... from the module
the_module = getattr(the_module, name)
except Exception as import_exception: # any exceptions during import
tb = import_exception.__traceback__
exception_str = f"{import_exception}"
else: # found the module
if version_args and version_checker(pkg, f"{version}", version_args):
return the_module, True
if not version_args and version_checker(pkg, f"{version}"):
return the_module, True

# preparing lazy error message
msg = descriptor.format(actual_cmd)
if version and tb is None: # a pure version issue
msg += f" (requires '{module} {version}' by '{version_checker.__name__}')"
if exception_str:
msg += f" ({exception_str})"

class _LazyRaise:
def __init__(self, *_args, **_kwargs):
_default_msg = (
f"{msg}."
+ "\n\nFor details about installing the optional dependencies, please visit:"
+ "\n https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies"
)
if tb is None:
self._exception = OptionalImportError(_default_msg)
else:
self._exception = OptionalImportError(_default_msg).with_traceback(tb)

def __getattr__(self, name):
"""
Raises:
OptionalImportError: When you call this method.
"""
raise self._exception

def __call__(self, *_args, **_kwargs):
"""
Raises:
OptionalImportError: When you call this method.
"""
raise self._exception

def __getitem__(self, item):
raise self._exception

def __iter__(self):
raise self._exception

if as_type == "default":
return _LazyRaise(), False

class _LazyCls(_LazyRaise):
def __init__(self, *_args, **kwargs):
super().__init__()
if not as_type.startswith("decorator"):
raise self._exception

return _LazyCls, False
Loading
Loading