Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions py/torch_tensorrt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import logging
import os
import platform
import pwd
import sys
import tempfile
from typing import Dict, List

import torch
Expand All @@ -24,6 +26,69 @@

import torch


def is_capture_tensorrt_api_recording_enabled() -> bool:
    """Return True when TensorRT API capture (shim recording) should be active.

    The feature is opted into via the ``TORCHTRT_ENABLE_TENSORRT_API_CAPTURE=1``
    environment variable, but it is only supported on Linux and only with
    standard TensorRT (not TensorRT-RTX). When the platform or backend is
    unsupported, a warning is logged, the environment variable is removed so
    downstream checks also see the feature as disabled, and False is returned.
    """
    if os.environ.get("TORCHTRT_ENABLE_TENSORRT_API_CAPTURE") != "1":
        return False
    if not sys.platform.startswith("linux"):
        # Lazy %-style logging args avoid building the message unless emitted.
        _LOGGER.warning(
            "Capturing TensorRT API calls is only supported on Linux, therefore ignoring the capture_tensorrt_api_recording setting for %s",
            sys.platform,
        )
        # Default of None makes the pop safe even if the variable vanished concurrently.
        os.environ.pop("TORCHTRT_ENABLE_TENSORRT_API_CAPTURE", None)
        return False
    if os.environ.get("USE_TRT_RTX", "False").lower() == "true":
        _LOGGER.warning(
            "Capturing TensorRT API calls is only supported on TensorRT, therefore ignoring the capture_tensorrt_api_recording setting for TensorRT-RTX"
        )
        os.environ.pop("TORCHTRT_ENABLE_TENSORRT_API_CAPTURE", None)
        return False
    return True


if is_capture_tensorrt_api_recording_enabled():
    # Candidate directories that may contain the TensorRT shim library:
    # everything on LD_LIBRARY_PATH plus the distro's multiarch lib directory.
    linux_lib_path = []
    if "LD_LIBRARY_PATH" in os.environ:
        linux_lib_path.extend(os.environ["LD_LIBRARY_PATH"].split(os.pathsep))

    # Hoist the uname() call; the processor string picks the multiarch triplet.
    processor = platform.uname().processor
    if processor == "x86_64":
        linux_lib_path.append("/usr/lib/x86_64-linux-gnu")
    elif processor == "aarch64":
        linux_lib_path.append("/usr/lib/aarch64-linux-gnu")

    # Probe each candidate: the first directory whose libtensorrt_shim.so loads
    # is treated as the TensorRT lib directory. RTLD_GLOBAL keeps the shim's
    # symbols visible to TensorRT libraries loaded afterwards so the shim can
    # intercept their API calls.
    tensorrt_lib_path = None
    for path in linux_lib_path:
        try:
            ctypes.CDLL(
                os.path.join(path, "libtensorrt_shim.so"), mode=ctypes.RTLD_GLOBAL
            )
            tensorrt_lib_path = path
            break
        except OSError:
            # CDLL raises OSError when the library is absent or unloadable;
            # keep probing the remaining candidates.
            continue

    if tensorrt_lib_path is None:
        _LOGGER.error(
            "Capturing TensorRT API calls is enabled, but libtensorrt_shim.so is not found, make sure TensorRT lib is in the LD_LIBRARY_PATH, therefore ignoring the capture_tensorrt_api_recording setting"
        )
        os.environ.pop("TORCHTRT_ENABLE_TENSORRT_API_CAPTURE", None)
    else:
        # Tell the shim which real libnvinfer to forward the captured calls to.
        os.environ["TRT_SHIM_NVINFER_LIB_NAME"] = os.path.join(
            tensorrt_lib_path, "libnvinfer.so"
        )
        # Per-user output directory avoids permission clashes on shared hosts.
        current_user = pwd.getpwuid(os.getuid()).pw_name
        shim_temp_dir = os.path.join(
            tempfile.gettempdir(), f"torch_tensorrt_{current_user}", "shim"
        )
        os.makedirs(shim_temp_dir, exist_ok=True)
        os.environ["TRT_SHIM_OUTPUT_JSON_FILE"] = os.path.join(
            shim_temp_dir, "shim.json"
        )
        _LOGGER.debug(
            f"capture_shim feature is enabled and the captured output is in the {shim_temp_dir} directory"
        )
else:
    _LOGGER.info("capture_shim feature is disabled")

tensorrt_package_name = ""

try:
Expand Down
9 changes: 9 additions & 0 deletions py/torch_tensorrt/dynamo/debug/_Debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
capture_fx_graph_before: Optional[List[str]] = None,
capture_fx_graph_after: Optional[List[str]] = None,
save_engine_profile: bool = False,
capture_tensorrt_api_recording: bool = False,
profile_format: str = "perfetto",
engine_builder_monitor: bool = True,
logging_dir: str = DEBUG_LOGGING_DIR,
Expand All @@ -49,6 +50,9 @@ def __init__(
after execution of a lowering pass. Defaults to None.
save_engine_profile (bool): Whether to save TensorRT engine profiling information.
Defaults to False.
            capture_tensorrt_api_recording (bool): Whether to enable the capture TensorRT API recording feature. When enabled, the captured TensorRT API recording is written to the /tmp/torch_tensorrt_{current_user}/shim directory.
                It is part of the TensorRT capture and replay feature; the captured output can be replayed for debugging purposes.
Defaults to False.
profile_format (str): Format for profiling data. Choose from 'perfetto', 'trex', 'cudagraph'.
If you need to generate engine graph using the profiling files, set it to 'trex' and use the C++ runtime.
If you need to generate cudagraph visualization, set it to 'cudagraph'.
Expand All @@ -65,6 +69,7 @@ def __init__(
self.cfg = DebuggerConfig(
log_level=log_level,
save_engine_profile=save_engine_profile,
capture_tensorrt_api_recording=capture_tensorrt_api_recording,
engine_builder_monitor=engine_builder_monitor,
logging_dir=logging_dir,
profile_format=profile_format,
Expand All @@ -91,6 +96,10 @@ def __init__(

self.capture_fx_graph_before = capture_fx_graph_before
self.capture_fx_graph_after = capture_fx_graph_after
if os.environ.get("TORCHTRT_ENABLE_TENSORRT_API_CAPTURE") == "1":
self.cfg.capture_tensorrt_api_recording = True
else:
self.cfg.capture_tensorrt_api_recording = False

def __enter__(self) -> None:
self.original_lvl = _LOGGER.getEffectiveLevel()
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/debug/_DebuggerConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
class DebuggerConfig:
log_level: str = "debug"
save_engine_profile: bool = False
capture_tensorrt_api_recording: bool = False
engine_builder_monitor: bool = True
logging_dir: str = DEBUG_LOGGING_DIR
profile_format: str = "perfetto"
Expand Down
37 changes: 37 additions & 0 deletions tools/debug/capture_replay/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
## Introduction

This toolchain captures TensorRT network creation and build parameters at runtime via a shim, then deterministically replays them to reproduce an engine build. Use it to debug or reproduce builds independent of the originating framework.

### Prerequisites
- TensorRT installed (ensure you know the absolute path to its `lib` and `bin` directories)
- `libtensorrt_shim.so` available in your TensorRT `lib` directory
- `tensorrt_player` available in your TensorRT `bin` directory

### Quick start: Capture

```bash
TORCHTRT_ENABLE_TENSORRT_API_CAPTURE=1 python test.py
```
You should see `shim.json` and `shim.bin` being generated in the `/tmp/torch_tensorrt_{current_user}/shim` directory.


### Replay: Build the engine from the capture
Use `tensorrt_player` to replay the captured build without the original framework:

```bash
tensorrt_player -j /absolute/path/to/shim.json -o /absolute/path/to/output_engine
```

This produces a serialized TensorRT engine at `output_engine`.

### Validate the engine
Run the engine with `trtexec`:

```bash
trtexec --loadEngine=/absolute/path/to/output_engine
```

### Notes
- Ensure the `libnvinfer.so` used by the shim matches the TensorRT version in your environment.
- If multiple TensorRT versions are installed, prefer absolute paths as shown above.
- The capture is best-effort; if your program builds multiple engines, multiple captures may be produced.
Loading