-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add post-install command to build PyTorch CPP extensions from within …
…onnxruntime package (#8027) ORTModule requires two PyTorch CPP extensions that are currently JIT compiled. The runtime compilation can cause issues in some environments without all build requirements or in environments with multiple instances of ORTModule running in parallel This PR creates a custom command to compile such extensions that must be manually executed before ORTModule is executed for the first time. When users try to use ORTModule before the extensions are compiled, an error with instructions are raised PyTorch CPP Extensions for ORTModule can be compiled by running: python -m onnxruntime.training.ortmodule.torch_cpp_extensions.install Full build environment is needed for this
- Loading branch information
Thiago Crepaldi
authored
Jun 29, 2021
1 parent
25db570
commit 83be375
Showing
36 changed files
with
333 additions
and
174 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
45 changes: 45 additions & 0 deletions
45
orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# ------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
# -------------------------------------------------------------------------- | ||
|
||
"""Support for PyTorch C++ extensions within ORTModule | ||
TODO: Implement mechanism to register extensions and prevent issues with incorrect/missing flags | ||
for each :meth:`torch.utils.cpp_extension.*` call | ||
""" | ||
|
||
import threading | ||
from functools import wraps | ||
from onnxruntime.capi import _pybind_state as C | ||
|
||
|
||
def run_once_aten_op_executor(f): | ||
""" | ||
Decorator to run a function only once. | ||
:param f: function to be run only once during execution time despite the number of calls | ||
:return: The original function with the params passed to it if it hasn't already been run before | ||
""" | ||
@wraps(f) | ||
def aten_op_executor_wrapper(*args, **kwargs): | ||
if not aten_op_executor_wrapper.has_run: | ||
with aten_op_executor_wrapper.lock: | ||
if not aten_op_executor_wrapper.has_run: | ||
aten_op_executor_wrapper.has_run = True | ||
return f(*args, **kwargs) | ||
|
||
aten_op_executor_wrapper.lock = threading.Lock() | ||
aten_op_executor_wrapper.has_run = False | ||
return aten_op_executor_wrapper | ||
|
||
@run_once_aten_op_executor | ||
def _load_aten_op_executor_cpp_extension(verbosity): | ||
from onnxruntime.training.ortmodule.torch_cpp_extensions import aten_op_executor | ||
C.register_aten_op_executor(str(aten_op_executor.execute_aten_operator_address())) | ||
|
||
def _load_aten_op_executor_cpp_extension_if_needed(onnx_model, verbosity): | ||
for node in onnx_model.graph.node: | ||
if node.op_type == 'ATenOp' and node.domain == 'com.microsoft': | ||
_load_aten_op_executor_cpp_extension(verbosity) | ||
break |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.