|
3 | 3 | import functools |
4 | 4 | import logging |
5 | 5 | import os |
| 6 | +import subprocess |
| 7 | +import sys |
6 | 8 | from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload |
7 | 9 |
|
8 | 10 | import numpy as np |
|
12 | 14 | from torch.fx.node import Argument, Target |
13 | 15 | from torch.fx.passes.shape_prop import TensorMetadata |
14 | 16 | from torch_tensorrt import _enums |
| 17 | +from torch_tensorrt._enums import Platform |
15 | 18 | from torch_tensorrt.dynamo._settings import CompilationSettings |
16 | 19 | from torch_tensorrt.dynamo._SourceIR import SourceIR |
17 | 20 | from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext |
@@ -930,57 +933,91 @@ def load_tensorrt_llm() -> bool: |
930 | 933 | Returns: |
931 | 934 | bool: True if the plugin was successfully loaded and initialized, False otherwise. |
932 | 935 | """ |
933 | | - try: |
934 | | - import tensorrt_llm as trt_llm # noqa: F401 |
935 | 936 |
|
936 | | - _LOGGER.info("TensorRT-LLM successfully imported") |
937 | | - return True |
938 | | - except (ImportError, AssertionError) as e_import_error: |
939 | | - # Check for environment variable for the plugin library path |
940 | | - plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH") |
941 | | - if not plugin_lib_path: |
| 937 | + plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH") |
| 938 | + if not plugin_lib_path: |
| 939 | + _LOGGER.warning( |
| 940 | + "Please set the TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops or else set the USE_TRTLLM_PLUGINS variable to download the shared library", |
| 941 | + ) |
| 942 | + use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in ( |
| 943 | + "1", |
| 944 | + "true", |
| 945 | + "yes", |
| 946 | + "on", |
| 947 | + ) |
| 948 | + if not use_trtllm_plugin: |
942 | 949 | _LOGGER.warning( |
943 | | - "TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops", |
| 950 | +                "Neither TRTLLM_PLUGINS_PATH is set nor is USE_TRTLLM_PLUGINS enabled to download the shared library" |
944 | 951 | ) |
945 | 952 | return False |
946 | | - |
947 | | - _LOGGER.info(f"TensorRT-LLM Plugin lib path found: {plugin_lib_path}") |
948 | | - try: |
949 | | - # Load the shared library |
950 | | - handle = ctypes.CDLL(plugin_lib_path) |
951 | | - _LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}") |
952 | | - except OSError as e_os_error: |
953 | | - _LOGGER.error( |
954 | | - f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}" |
955 | | - f"Ensure the path is correct and the library is compatible", |
956 | | - exc_info=e_os_error, |
| 953 | + else: |
| 954 | + py_version = f"cp{sys.version_info.major}{sys.version_info.minor}" |
| 955 | + platform = Platform.current_platform() |
| 956 | +        if platform == Platform.LINUX_X86_64: |
| 957 | +            platform = "linux_x86_64" |
| 958 | +        elif platform == Platform.LINUX_AARCH64: |
| 959 | +            platform = "linux_aarch64" |
| 960 | + |
| 961 | + if py_version not in ("cp310", "cp312"): |
| 962 | + _LOGGER.warning( |
| 963 | + "No available wheel for python versions other than py3.10 and py3.12" |
| 964 | + ) |
| 965 | + if py_version == "cp310" and platform == "linux_aarch64": |
| 966 | + _LOGGER.warning("No available wheel for python3.10 with Linux aarch64") |
| 967 | + |
| 968 | + base_url = "https://pypi.nvidia.com/tensorrt-llm/" |
| 969 | + file_name = ( |
| 970 | +            f"tensorrt_llm-0.17.0.post1-{py_version}-{py_version}-{platform}.whl" |
957 | 971 | ) |
958 | | - return False |
| 972 | + download_url = base_url + file_name |
| 973 | + cmd = ["wget", download_url] |
| 974 | + subprocess.run(cmd) |
| 975 | + if os.path.exists(file_name): |
| 976 | +            _LOGGER.info(f"{file_name} download is completed") |
| 977 | + import zipfile |
| 978 | + |
| 979 | + with zipfile.ZipFile(file_name, "r") as zip_ref: |
| 980 | + zip_ref.extractall( |
| 981 | + "./tensorrt_llm" |
| 982 | + ) # Extract to a folder named 'tensorrt_llm' |
| 983 | + plugin_lib_path = ( |
| 984 | +                "./tensorrt_llm" + "/libnvinfer_plugin_tensorrt_llm.so" |
| 985 | + ) |
| 986 | + try: |
| 987 | + # Load the shared library |
| 988 | + handle = ctypes.CDLL(plugin_lib_path) |
| 989 | + _LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}") |
| 990 | + except OSError as e_os_error: |
| 991 | + _LOGGER.error( |
| 992 | + f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}" |
| 993 | + f"Ensure the path is correct and the library is compatible", |
| 994 | + exc_info=e_os_error, |
| 995 | + ) |
| 996 | + return False |
959 | 997 |
|
960 | | - try: |
961 | | - # Configure plugin initialization arguments |
962 | | - handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p] |
963 | | - handle.initTrtLlmPlugins.restype = ctypes.c_bool |
964 | | - except AttributeError as e_plugin_unavailable: |
965 | | - _LOGGER.warning( |
966 | | - "Unable to initialize the TensorRT-LLM plugin library", |
967 | | - exc_info=e_plugin_unavailable, |
968 | | - ) |
969 | | - return False |
| 998 | + try: |
| 999 | + # Configure plugin initialization arguments |
| 1000 | + handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p] |
| 1001 | + handle.initTrtLlmPlugins.restype = ctypes.c_bool |
| 1002 | + except AttributeError as e_plugin_unavailable: |
| 1003 | + _LOGGER.warning( |
| 1004 | + "Unable to initialize the TensorRT-LLM plugin library", |
| 1005 | + exc_info=e_plugin_unavailable, |
| 1006 | + ) |
| 1007 | + return False |
970 | 1008 |
|
971 | | - try: |
972 | | - # Initialize the plugin |
973 | | - TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm" |
974 | | - if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")): |
975 | | - _LOGGER.info("TensorRT-LLM plugin successfully initialized") |
976 | | - return True |
977 | | - else: |
978 | | - _LOGGER.warning("TensorRT-LLM plugin library failed in initialization") |
979 | | - return False |
980 | | - except Exception as e_initialization_error: |
981 | | - _LOGGER.warning( |
982 | | - "Exception occurred during TensorRT-LLM plugin library initialization", |
983 | | - exc_info=e_initialization_error, |
984 | | - ) |
| 1009 | + try: |
| 1010 | + # Initialize the plugin |
| 1011 | + TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm" |
| 1012 | + if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")): |
| 1013 | + _LOGGER.info("TensorRT-LLM plugin successfully initialized") |
| 1014 | + return True |
| 1015 | + else: |
| 1016 | + _LOGGER.warning("TensorRT-LLM plugin library failed in initialization") |
985 | 1017 | return False |
986 | | - return False |
| 1018 | + except Exception as e_initialization_error: |
| 1019 | + _LOGGER.warning( |
| 1020 | + "Exception occurred during TensorRT-LLM plugin library initialization", |
| 1021 | + exc_info=e_initialization_error, |
| 1022 | + ) |
| 1023 | + return False |
0 commit comments