Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions docsrc/getting_started/capture_and_replay.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
Introduction
============

This toolchain captures TensorRT network creation and build parameters at runtime via a shim, then deterministically replays them to reproduce an engine build. Use it to debug or reproduce builds independent of the originating framework.

Prerequisites
-------------

- TensorRT installed (ensure you know the absolute path to its ``lib`` and ``bin`` directories)
- ``libtensorrt_shim.so`` available in your TensorRT ``lib`` directory
- ``tensorrt_player`` available in your TensorRT ``bin`` directory

Quick start: Capture
--------------------

.. code-block:: bash

TORCHTRT_ENABLE_TENSORRT_API_CAPTURE=1 python test.py

You should see ``shim.json`` and ``shim.bin`` generated in ``/tmp/torch_tensorrt_{current_user}/shim``.

Replay: Build the engine from the capture
-----------------------------------------

Use ``tensorrt_player`` to replay the captured build without the original framework:

.. code-block:: bash

tensorrt_player -j /absolute/path/to/shim.json -o /absolute/path/to/output_engine

This produces a serialized TensorRT engine at ``output_engine``.

Validate the engine
-------------------

Run the engine with ``trtexec``:

.. code-block:: bash

trtexec --loadEngine=/absolute/path/to/output_engine

Notes
-----

- Ensure the ``libnvinfer.so`` used by the shim matches the TensorRT version in your environment.
- If multiple TensorRT versions are installed, prefer absolute paths as shown above.
- Capturing multiple engines is currently not supported; if a graph break occurs, only the first engine will be captured.


1 change: 1 addition & 0 deletions docsrc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ Getting Started
getting_started/jetpack
getting_started/quick_start
getting_started/tensorrt_rtx
getting_started/capture_and_replay

User Guide
------------
Expand Down
69 changes: 69 additions & 0 deletions py/torch_tensorrt/_TensorRTProxyModule.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import ctypes
import importlib
import importlib.util
import logging
import os
import platform
import pwd
import sys
import tempfile
from types import ModuleType
from typing import Any, Dict, List

_LOGGER = logging.getLogger(__name__)
package_imported = False
package_name = ""

Expand All @@ -28,6 +32,66 @@ def _find_lib(name: str, paths: List[str]) -> str:
raise FileNotFoundError(f"Could not find {name}\n Search paths: {paths}")


def enable_capture_tensorrt_api_recording() -> None:
    """Enable capture of TensorRT API calls by preloading ``libtensorrt_shim.so``.

    Reads the ``TORCHTRT_ENABLE_TENSORRT_API_CAPTURE`` environment variable;
    when it is set to ``"1"`` or ``"true"``, searches ``LD_LIBRARY_PATH`` (plus
    the platform's default multiarch lib directory) for ``libtensorrt_shim.so``,
    loads it with ``RTLD_GLOBAL``, and sets the shim's output environment
    variables so the capture is written under
    ``{tempdir}/torch_tensorrt_{current_user}/shim``.

    Must run before the ``tensorrt`` package is imported so the shim can
    intercept the API calls. Linux-only: on other platforms the setting is
    ignored and the environment variable is removed.
    """
    os_env_flag = os.environ.get("TORCHTRT_ENABLE_TENSORRT_API_CAPTURE", None)
    if os_env_flag is None or (os_env_flag != "1" and os_env_flag.lower() != "true"):
        _LOGGER.debug("Capturing TensorRT API calls is not enabled")
        return
    if not sys.platform.startswith("linux"):
        _LOGGER.warning(
            f"Capturing TensorRT API calls is only supported on Linux, therefore ignoring the capture_tensorrt_api_recording setting for {sys.platform}"
        )
        os.environ.pop("TORCHTRT_ENABLE_TENSORRT_API_CAPTURE")
        return

    linux_lib_path = []
    if "LD_LIBRARY_PATH" in os.environ:
        linux_lib_path.extend(os.environ["LD_LIBRARY_PATH"].split(os.path.pathsep))

    # Fall back to the distro's default multiarch library directory.
    if platform.uname().processor == "x86_64":
        linux_lib_path.append("/usr/lib/x86_64-linux-gnu")
    elif platform.uname().processor == "aarch64":
        linux_lib_path.append("/usr/lib/aarch64-linux-gnu")

    # BUGFIX: initialize before the search loop. Previously this name was only
    # bound inside the loop, so when no candidate directory contained the shim
    # the "is None" check below raised NameError instead of warning gracefully.
    tensorrt_lib_path = None
    for path in linux_lib_path:
        if os.path.isfile(os.path.join(path, "libtensorrt_shim.so")):
            try:
                # RTLD_GLOBAL so the shim's symbols are visible when the
                # tensorrt package loads libnvinfer later.
                ctypes.CDLL(
                    os.path.join(path, "libtensorrt_shim.so"), mode=ctypes.RTLD_GLOBAL
                )
                tensorrt_lib_path = path
                break
            except OSError:
                # This candidate exists but could not be loaded; keep searching.
                continue

    if tensorrt_lib_path is None:
        _LOGGER.warning(
            "Capturing TensorRT API calls is enabled, but libtensorrt_shim.so is not found, make sure TensorRT lib is in the LD_LIBRARY_PATH, therefore ignoring the capture_tensorrt_api_recording setting"
        )
        os.environ.pop("TORCHTRT_ENABLE_TENSORRT_API_CAPTURE")
    else:
        os.environ["TRT_SHIM_NVINFER_LIB_NAME"] = os.path.join(
            tensorrt_lib_path, "libnvinfer.so"
        )
        current_user = pwd.getpwuid(os.getuid())[0]
        shim_temp_dir = os.path.join(
            tempfile.gettempdir(), f"torch_tensorrt_{current_user}/shim"
        )
        os.makedirs(shim_temp_dir, exist_ok=True)
        json_file_name = os.path.join(shim_temp_dir, "shim.json")
        os.environ["TRT_SHIM_OUTPUT_JSON_FILE"] = json_file_name
        bin_file_name = os.path.join(shim_temp_dir, "shim.bin")
        # Remove stale capture files so a fresh run produces a clean capture.
        if os.path.exists(json_file_name):
            os.remove(json_file_name)
        if os.path.exists(bin_file_name):
            os.remove(bin_file_name)
        _LOGGER.info(
            f"Capturing TensorRT API calls feature is enabled and the captured output is in the {shim_temp_dir} directory"
        )


# TensorRTProxyModule is a proxy module that allows us to register the tensorrt or tensorrt-rtx package
# since tensorrt-rtx is the drop-in replacement for tensorrt, we can use the same interface to use tensorrt-rtx
class TensorRTProxyModule(ModuleType):
Expand Down Expand Up @@ -86,6 +150,11 @@ def alias_tensorrt() -> None:
if use_rtx_env_var.lower() == "true":
use_rtx = True
package_name = "tensorrt_rtx" if use_rtx else "tensorrt"

if not use_rtx:
# enable capture tensorrt api recording has to be done before importing the tensorrt library
enable_capture_tensorrt_api_recording()

# Import the appropriate package
try:
target_module = importlib.import_module(package_name)
Expand Down
23 changes: 23 additions & 0 deletions py/torch_tensorrt/dynamo/debug/_Debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import functools
import logging
import os
import sys
import tempfile
from logging.config import dictConfig
from typing import Any, List, Optional
Expand Down Expand Up @@ -32,6 +33,7 @@ def __init__(
capture_fx_graph_before: Optional[List[str]] = None,
capture_fx_graph_after: Optional[List[str]] = None,
save_engine_profile: bool = False,
capture_tensorrt_api_recording: bool = False,
profile_format: str = "perfetto",
engine_builder_monitor: bool = True,
logging_dir: str = DEBUG_LOGGING_DIR,
Expand All @@ -49,6 +51,9 @@ def __init__(
after execution of a lowering pass. Defaults to None.
save_engine_profile (bool): Whether to save TensorRT engine profiling information.
Defaults to False.
capture_tensorrt_api_recording (bool): Whether to enable the TensorRT API recording capture feature. When enabled, the captured TensorRT API recording is written to the /tmp/torch_tensorrt_{current_user}/shim directory.
It is part of the TensorRT capture-and-replay feature; the captured output can be replayed for debugging purposes.
Defaults to False.
profile_format (str): Format for profiling data. Choose from 'perfetto', 'trex', 'cudagraph'.
If you need to generate engine graph using the profiling files, set it to 'trex' and use the C++ runtime.
If you need to generate cudagraph visualization, set it to 'cudagraph'.
Expand All @@ -65,6 +70,7 @@ def __init__(
self.cfg = DebuggerConfig(
log_level=log_level,
save_engine_profile=save_engine_profile,
capture_tensorrt_api_recording=capture_tensorrt_api_recording,
engine_builder_monitor=engine_builder_monitor,
logging_dir=logging_dir,
profile_format=profile_format,
Expand Down Expand Up @@ -92,6 +98,23 @@ def __init__(
self.capture_fx_graph_before = capture_fx_graph_before
self.capture_fx_graph_after = capture_fx_graph_after

if self.cfg.capture_tensorrt_api_recording:
if not sys.platform.startswith("linux"):
_LOGGER.warning(
f"Capturing TensorRT API calls is only supported on Linux, therefore ignoring the capture_tensorrt_api_recording setting for {sys.platform}"
)
elif ENABLED_FEATURES.tensorrt_rtx:
_LOGGER.warning(
"Capturing TensorRT API calls is not supported for TensorRT-RTX, therefore ignoring the capture_tensorrt_api_recording setting"
)
else:
env_flag = os.environ.get("TORCHTRT_ENABLE_TENSORRT_API_CAPTURE", None)
if env_flag is None or (env_flag != "1" and env_flag.lower() != "true"):
_LOGGER.warning(
"In order to capture TensorRT API calls, please invoke the script with environment variable TORCHTRT_ENABLE_TENSORRT_API_CAPTURE=1"
)
_LOGGER.info("Capturing TensorRT API calls feature is enabled")

def __enter__(self) -> None:
self.original_lvl = _LOGGER.getEffectiveLevel()
if ENABLED_FEATURES.torch_tensorrt_runtime:
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
class DebuggerConfig:
log_level: str = "debug"
save_engine_profile: bool = False
capture_tensorrt_api_recording: bool = False
engine_builder_monitor: bool = True
logging_dir: str = DEBUG_LOGGING_DIR
profile_format: str = "perfetto"
Expand Down
Loading