diff --git a/docsrc/getting_started/capture_and_replay.rst b/docsrc/getting_started/capture_and_replay.rst new file mode 100644 index 0000000000..e04e1899c7 --- /dev/null +++ b/docsrc/getting_started/capture_and_replay.rst @@ -0,0 +1,49 @@ +Introduction +============ + +This toolchain captures TensorRT network creation and build parameters at runtime via a shim, then deterministically replays them to reproduce an engine build. Use it to debug or reproduce builds independent of the originating framework. + +Prerequisites +------------- + +- TensorRT installed (ensure you know the absolute path to its ``lib`` and ``bin`` directories) +- ``libtensorrt_shim.so`` available in your TensorRT ``lib`` directory +- ``tensorrt_player`` available in your TensorRT ``bin`` directory + +Quick start: Capture +-------------------- + +.. code-block:: bash + + TORCHTRT_ENABLE_TENSORRT_API_CAPTURE=1 python test.py + +You should see ``shim.json`` and ``shim.bin`` generated in ``/tmp/torch_tensorrt_{current_user}/shim``. + +Replay: Build the engine from the capture +----------------------------------------- + +Use ``tensorrt_player`` to replay the captured build without the original framework: + +.. code-block:: bash + + tensorrt_player -j /absolute/path/to/shim.json -o /absolute/path/to/output_engine + +This produces a serialized TensorRT engine at ``output_engine``. + +Validate the engine +------------------- + +Run the engine with ``trtexec``: + +.. code-block:: bash + + trtexec --loadEngine=/absolute/path/to/output_engine + +Notes +----- + +- Ensure the ``libnvinfer.so`` used by the shim matches the TensorRT version in your environment. +- If multiple TensorRT versions are installed, prefer absolute paths as shown above. +- Currently, capturing multiple engines is not supported; in the case of a graph break, only the first engine will be captured. 
_LOGGER = logging.getLogger(__name__)


def enable_capture_tensorrt_api_recording() -> None:
    """Preload the TensorRT shim library so that TensorRT API calls are captured.

    Gated by the ``TORCHTRT_ENABLE_TENSORRT_API_CAPTURE`` environment variable
    ("1" or "true", case-insensitive, enables it). Linux-only: on any other
    platform a warning is logged and the flag is cleared. ``libtensorrt_shim.so``
    is searched for on ``LD_LIBRARY_PATH`` plus the distro multiarch lib
    directory; when found it is loaded with ``RTLD_GLOBAL`` (this must happen
    *before* the ``tensorrt`` python package is imported), and the shim output
    locations (``shim.json`` / ``shim.bin``) are pointed at
    ``{tempdir}/torch_tensorrt_{user}/shim``, removing any stale capture files.
    If the shim cannot be found or loaded, the capture flag is cleared and a
    warning is logged; the function never raises in that case.
    """
    os_env_flag = os.environ.get("TORCHTRT_ENABLE_TENSORRT_API_CAPTURE", None)
    if os_env_flag is None or (os_env_flag != "1" and os_env_flag.lower() != "true"):
        _LOGGER.debug("Capturing TensorRT API calls is not enabled")
        return
    if not sys.platform.startswith("linux"):
        _LOGGER.warning(
            f"Capturing TensorRT API calls is only supported on Linux, therefore ignoring the capture_tensorrt_api_recording setting for {sys.platform}"
        )
        os.environ.pop("TORCHTRT_ENABLE_TENSORRT_API_CAPTURE")
        return

    # Candidate directories for libtensorrt_shim.so: LD_LIBRARY_PATH entries
    # plus the distro multiarch library directory for this architecture.
    linux_lib_path = []
    if "LD_LIBRARY_PATH" in os.environ:
        linux_lib_path.extend(os.environ["LD_LIBRARY_PATH"].split(os.path.pathsep))

    if platform.uname().processor == "x86_64":
        linux_lib_path.append("/usr/lib/x86_64-linux-gnu")
    elif platform.uname().processor == "aarch64":
        linux_lib_path.append("/usr/lib/aarch64-linux-gnu")

    # BUG FIX: initialize before the search loop. Previously this name was only
    # bound inside the loop, so the `tensorrt_lib_path is None` check below
    # raised NameError whenever the shim was not found on any candidate path.
    tensorrt_lib_path = None
    for path in linux_lib_path:
        if os.path.isfile(os.path.join(path, "libtensorrt_shim.so")):
            try:
                # RTLD_GLOBAL so the shim's symbols interpose on libnvinfer
                # before the tensorrt python package loads it.
                ctypes.CDLL(
                    os.path.join(path, "libtensorrt_shim.so"), mode=ctypes.RTLD_GLOBAL
                )
                tensorrt_lib_path = path
                break
            except Exception:
                # A shim that exists but fails to load is skipped; keep looking.
                continue

    if tensorrt_lib_path is None:
        _LOGGER.warning(
            "Capturing TensorRT API calls is enabled, but libtensorrt_shim.so is not found, make sure TensorRT lib is in the LD_LIBRARY_PATH, therefore ignoring the capture_tensorrt_api_recording setting"
        )
        os.environ.pop("TORCHTRT_ENABLE_TENSORRT_API_CAPTURE")
    else:
        # Tell the shim which libnvinfer.so to forward the captured calls to.
        os.environ["TRT_SHIM_NVINFER_LIB_NAME"] = os.path.join(
            tensorrt_lib_path, "libnvinfer.so"
        )
        current_user = pwd.getpwuid(os.getuid())[0]
        shim_temp_dir = os.path.join(
            tempfile.gettempdir(), f"torch_tensorrt_{current_user}/shim"
        )
        os.makedirs(shim_temp_dir, exist_ok=True)
        json_file_name = os.path.join(shim_temp_dir, "shim.json")
        os.environ["TRT_SHIM_OUTPUT_JSON_FILE"] = json_file_name
        bin_file_name = os.path.join(shim_temp_dir, "shim.bin")
        # Remove stale captures so this run's output is unambiguous.
        if os.path.exists(json_file_name):
            os.remove(json_file_name)
        if os.path.exists(bin_file_name):
            os.remove(bin_file_name)
        _LOGGER.info(
            f"Capturing TensorRT API calls feature is enabled and the captured output is in the {shim_temp_dir} directory"
        )
b/py/torch_tensorrt/dynamo/debug/_Debugger.py index ec624ffc5a..39e4217f73 100644 --- a/py/torch_tensorrt/dynamo/debug/_Debugger.py +++ b/py/torch_tensorrt/dynamo/debug/_Debugger.py @@ -2,6 +2,7 @@ import functools import logging import os +import sys import tempfile from logging.config import dictConfig from typing import Any, List, Optional @@ -32,6 +33,7 @@ def __init__( capture_fx_graph_before: Optional[List[str]] = None, capture_fx_graph_after: Optional[List[str]] = None, save_engine_profile: bool = False, + capture_tensorrt_api_recording: bool = False, profile_format: str = "perfetto", engine_builder_monitor: bool = True, logging_dir: str = DEBUG_LOGGING_DIR, @@ -49,6 +51,9 @@ def __init__( after execution of a lowering pass. Defaults to None. save_engine_profile (bool): Whether to save TensorRT engine profiling information. Defaults to False. + capture_tensorrt_api_recording (bool): Whether to enable the capture TensorRT API recording feature, when this is enabled, it will output the captured TensorRT API recording in the /tmp/torch_tensorrt_{current_user}/shim directory. + It is part of the TensorRT capture and replay feature, the captured output can be replayed for debugging purposes. + Defaults to False. profile_format (str): Format for profiling data. Choose from 'perfetto', 'trex', 'cudagraph'. If you need to generate engine graph using the profiling files, set it to 'trex' and use the C++ runtime. If you need to generate cudagraph visualization, set it to 'cudagraph'. 
@@ -65,6 +70,7 @@ def __init__( self.cfg = DebuggerConfig( log_level=log_level, save_engine_profile=save_engine_profile, + capture_tensorrt_api_recording=capture_tensorrt_api_recording, engine_builder_monitor=engine_builder_monitor, logging_dir=logging_dir, profile_format=profile_format, @@ -92,6 +98,23 @@ def __init__( self.capture_fx_graph_before = capture_fx_graph_before self.capture_fx_graph_after = capture_fx_graph_after + if self.cfg.capture_tensorrt_api_recording: + if not sys.platform.startswith("linux"): + _LOGGER.warning( + f"Capturing TensorRT API calls is only supported on Linux, therefore ignoring the capture_tensorrt_api_recording setting for {sys.platform}" + ) + elif ENABLED_FEATURES.tensorrt_rtx: + _LOGGER.warning( + "Capturing TensorRT API calls is not supported for TensorRT-RTX, therefore ignoring the capture_tensorrt_api_recording setting" + ) + else: + env_flag = os.environ.get("TORCHTRT_ENABLE_TENSORRT_API_CAPTURE", None) + if env_flag is None or (env_flag != "1" and env_flag.lower() != "true"): + _LOGGER.warning( + "In order to capture TensorRT API calls, please invoke the script with environment variable TORCHTRT_ENABLE_TENSORRT_API_CAPTURE=1" + ) + _LOGGER.info("Capturing TensorRT API calls feature is enabled") + def __enter__(self) -> None: self.original_lvl = _LOGGER.getEffectiveLevel() if ENABLED_FEATURES.torch_tensorrt_runtime: diff --git a/py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py b/py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py index 27a5025e8b..82cd3ba83a 100644 --- a/py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py +++ b/py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py @@ -7,6 +7,7 @@ class DebuggerConfig: log_level: str = "debug" save_engine_profile: bool = False + capture_tensorrt_api_recording: bool = False engine_builder_monitor: bool = True logging_dir: str = DEBUG_LOGGING_DIR profile_format: str = "perfetto"