diff --git a/examples/distributed_inference/tensor_parallel_initialize_dist.py b/examples/distributed_inference/tensor_parallel_initialize_dist.py
index 5fffb3fa00..cd65558c6b 100644
--- a/examples/distributed_inference/tensor_parallel_initialize_dist.py
+++ b/examples/distributed_inference/tensor_parallel_initialize_dist.py
@@ -3,7 +3,7 @@
 Tensor Parallel Initialize Distributed Environment
 ==================================================
 
-This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference.
+This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference. These utilities are used by the tensor parallel examples that run on torch.distributed.
 """
 
 import logging
@@ -19,8 +19,65 @@
 logger = logging.getLogger(__name__)
 
 
-# this is kept at the application level, when mpirun is used to run the application
-def initialize_distributed_env(rank=0, world_size=1, port=29500):
+def initialize_logger(
+    rank, logger_file_name, file_level=logging.DEBUG, console_level=logging.INFO
+):
+    """Initialize a rank-specific Torch-TensorRT logger with configurable handler levels.
+
+    The logger level is set to DEBUG (pass-through); the handlers control
+    filtering for the file and console outputs.
+
+    Args:
+        rank: Process rank for multi-GPU runs
+        logger_file_name: Base name for the log file (a _{rank}.log suffix is appended)
+        file_level: What goes to the file; default DEBUG (everything)
+        console_level: What prints to the console; default INFO (clean output)
+    """
+    logger = logging.getLogger("torch_tensorrt")
+    logger.setLevel(logging.DEBUG)
+    logger.handlers.clear()
+
+    # File handler
+    fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
+    fh.setLevel(file_level)
+    fh.setFormatter(
+        logging.Formatter(
+            f"[Rank {rank}] %(asctime)s - %(name)s - %(levelname)s - %(message)s"
+        )
+    )
+    logger.addHandler(fh)
+
+    # Console handler
+    ch = logging.StreamHandler()
+    ch.setLevel(
+        console_level
+    )  # The console handler controls what is printed to the console
+    ch.setFormatter(logging.Formatter(f"[Rank {rank}] %(levelname)s: %(message)s"))
+    logger.addHandler(ch)
+
+    # Safeguard against duplicate logging via the root logger (not strictly required)
+    logger.propagate = False
+    return logger
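Reviewer note: a minimal usage sketch of the new `initialize_logger` helper, showing the file/console handler split. The module import path is assumed from the example's filename; the names here are illustrative and not part of the diff.

```python
# Hypothetical usage sketch (module path assumed from the example filename)
import logging

from tensor_parallel_initialize_dist import initialize_logger

log = initialize_logger(
    rank=0,
    logger_file_name="tp_example",  # the file handler writes to tp_example_0.log
    file_level=logging.DEBUG,       # everything is captured in the file
    console_level=logging.WARNING,  # the console shows warnings and above only
)
log.debug("appears only in tp_example_0.log")
log.warning("appears in both the log file and the console")
```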
+
+
+# Required for environment initialization, since the application is launched with mpirun
+def initialize_distributed_env(
+    logger_file_name,
+    rank=0,
+    world_size=1,
+    port=29500,
+    file_level="debug",
+    console_level="info",
+):
+    """Initialize the distributed environment with handler-based logging.
+
+    Args:
+        logger_file_name: Base name for log files
+        rank: Initial rank (overridden by OMPI env vars)
+        world_size: Initial world size (overridden by OMPI env vars)
+        port: Master port for distributed communication
+        file_level: File handler level - "debug", "info", "warning", "error" (default: "debug")
+        console_level: Console handler level - "debug", "info", "warning", "error" (default: "info")
+    """
     local_rank = int(
         os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
     )
@@ -44,12 +101,40 @@ def initialize_distributed_env(rank=0, world_size=1, port=29500):
     device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
     rank = device_mesh.get_rank()
     assert rank == local_rank
+    # Convert string handler levels to logging constants
+    level_map = {
+        "debug": logging.DEBUG,
+        "info": logging.INFO,
+        "warning": logging.WARNING,
+        "error": logging.ERROR,
+    }
+    file_level_int = level_map.get(file_level.lower(), logging.DEBUG)
+    console_level_int = level_map.get(console_level.lower(), logging.INFO)
+
+    # Initialize the logger with handler-specific levels
+    # The logger itself is always DEBUG - the handlers do the filtering
+    logger = initialize_logger(
+        rank,
+        logger_file_name,
+        file_level=file_level_int,
+        console_level=console_level_int,
+    )
     device_id = (
         rank % torch.cuda.device_count()
     )  # Ensure each rank gets a unique device
     torch.cuda.set_device(device_id)
 
-    return device_mesh, world_size, rank
+    # Set the C++ TensorRT runtime log level from the most verbose handler
+    # so that all important logs are captured
+    cpp_level = min(file_level_int, console_level_int)
+    try:
+        import torch_tensorrt.logging as torchtrt_logging
+
+        torchtrt_logging.set_level(cpp_level)
+    except Exception as e:
+        logger.warning(f"Could not set C++ TensorRT log level: {e}")
+
+    return device_mesh, world_size, rank, logger
 
 
 def cleanup_distributed_env():
diff --git a/py/torch_tensorrt/_features.py b/py/torch_tensorrt/_features.py
index 03cf4256ec..6fd7db5551 100644
--- a/py/torch_tensorrt/_features.py
+++ b/py/torch_tensorrt/_features.py
@@ -74,10 +74,29 @@ def _enabled_features_str() -> str:
     enabled = lambda x: "ENABLED" if x else "DISABLED"
 
-    out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n - Refit: {enabled(_REFIT_AVAIL)}\n - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)} \n - TensorRT-RTX: {enabled(_TENSORRT_RTX)}\n"  # type: ignore[no-untyped-call]
+    out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n - Refit: {enabled(_REFIT_AVAIL)}\n - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)} \n - TensorRT-RTX: {enabled(_TENSORRT_RTX)}\n - TensorRT-LLM for NCCL: {enabled(_TRTLLM_AVAIL)}\n"  # type: ignore[no-untyped-call]
 
     return out_str
 
 
+# Inline helper functions for checking feature availability
+def has_torch_tensorrt_runtime() -> bool:
+    """Check if the Torch-TensorRT C++ runtime is available.
+
+    Returns:
+        bool: True if libtorchtrt_runtime.so or libtorchtrt.so is available
+    """
+    return bool(ENABLED_FEATURES.torch_tensorrt_runtime)
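Reviewer note: a sketch of how the new inline helper can guard runtime-only calls; this mirrors how the context managers in logging.py use it later in this diff. It assumes only the helper added above.

```python
# Sketch: guard a C++ runtime call with the helper added in this diff
import logging

import torch
from torch_tensorrt._features import has_torch_tensorrt_runtime

if has_torch_tensorrt_runtime():
    # Safe: the C++ runtime op library is loaded, so the custom op exists
    current_level = torch.ops.tensorrt.get_logging_level()
else:
    logging.getLogger(__name__).info("C++ runtime unavailable; skipping")
```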
+
+
+def has_torchscript_frontend() -> bool:
+    """Check if the TorchScript frontend is available.
+
+    Returns:
+        bool: True if libtorchtrt.so is available
+    """
+    return bool(ENABLED_FEATURES.torchscript_frontend)
+
+
 def needs_tensorrt_rtx(f: Callable[..., Any]) -> Callable[..., Any]:
     def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
         if ENABLED_FEATURES.tensorrt_rtx:
@@ -163,6 +182,19 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
 
 
 def needs_trtllm_for_nccl(f: Callable[..., Any]) -> Callable[..., Any]:
+    """
+    Runtime-check decorator for TensorRT-LLM NCCL plugin availability.
+
+    WARNING: This decorator CANNOT prevent registration of converters at import time.
+    When used with @dynamo_tensorrt_converter, the converter is always registered
+    regardless of decorator order, because registration happens at import time,
+    before the wrapper is called.
+
+    This decorator is kept for potential non-registration use cases where
+    runtime checks are appropriate.
+    TODO (apbose): discuss whether this decorator is still required.
+    """
+
     def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
         if ENABLED_FEATURES.trtllm_for_nccl:
             return f(*args, **kwargs)
@@ -170,7 +202,7 @@ def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
 
         def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
             raise NotImplementedError(
-                "Refit feature is currently not available in Python 3.13 or higher"
+                "TensorRT-LLM plugin for NCCL is not available"
             )
 
         return not_implemented(*args, **kwargs)
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index b834d8087f..a229036e27 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -7,6 +7,7 @@
 
 import torch
 import torch._dynamo as td
+import torch_tensorrt.logging as torchtrt_logging
 from torch._dynamo.backends.common import aot_autograd
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import aot_export_joint_simple
@@ -23,7 +24,6 @@
 from torch_tensorrt.dynamo.utils import (
     parse_dynamo_kwargs,
     prepare_inputs,
-    set_log_level,
 )
 
 logger = logging.getLogger(__name__)
@@ -40,7 +40,7 @@ def torch_tensorrt_backend(
         and "debug" in kwargs["options"]
         and kwargs["options"]["debug"]
     ) or ("debug" in kwargs and kwargs["debug"]):
-        set_log_level(logger.parent, logging.DEBUG)
+        torchtrt_logging.set_level(logging.DEBUG)
 
     DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
diff --git a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
index db14e3528b..302a254f60 100644
--- a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
@@ -5,7 +5,7 @@
 
 import tensorrt as trt
 from torch.fx.node import Argument, Target
-from torch_tensorrt._features import needs_trtllm_for_nccl
+from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -20,37 +20,53 @@
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
-@needs_trtllm_for_nccl
-@dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
-def fused_nccl_gather(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
-    return impl.nccl_ops.nccl_gather(
-        ctx,
-        target,
-        SourceIR.ATEN,
-        name,
-        [args[0]],
-    )
-
-
+# Conditionally register the NCCL converters only if the TensorRT-LLM plugin is
+# available. We use an `if` statement instead of the @needs_trtllm_for_nccl
+# decorator because @dynamo_tensorrt_converter ALWAYS registers the converter at
+# import time, regardless of decorator order. Conditional registration prevents
+# registration when TRT-LLM is unavailable, allowing NCCL ops to fall back to
+# PyTorch execution.
+#
+# Order 1 (@needs_trtllm_for_nccl above @dynamo_tensorrt_converter): the bare
+# converter is registered without the runtime check, so the plugin registry
+# cannot find the NCCL ops plugins.
+# Order 2 (@dynamo_tensorrt_converter above @needs_trtllm_for_nccl): conversion
+# raises "NotImplementedError: TensorRT-LLM plugin for NCCL is not available"
+# with no fallback to PyTorch.
+if ENABLED_FEATURES.trtllm_for_nccl:
+    _LOGGER.debug(
+        "TensorRT-LLM plugin for NCCL is available. Registering NCCL converters."
+    )
+
+    @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
+    def fused_nccl_gather(
+        ctx: ConversionContext,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        name: str,
+    ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
+        return impl.nccl_ops.nccl_gather(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name,
+            [args[0]],
+        )
+
+    @dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op)
+    def fused_nccl_reduce_scatter(
+        ctx: ConversionContext,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        name: str,
+    ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
+        return impl.nccl_ops.nccl_reduce_scatter(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name,
+            [args[0]],
+        )
+
-@needs_trtllm_for_nccl
-@dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op)
-def fused_nccl_reduce_scatter(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
-    return impl.nccl_ops.nccl_reduce_scatter(
-        ctx,
-        target,
-        SourceIR.ATEN,
-        name,
-        [args[0]],
-    )
+else:
+    _LOGGER.info(
+        "TensorRT-LLM plugin for NCCL is not available. "
+        "NCCL operations will fall back to PyTorch execution."
+    )
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index abc697a086..5ddb814e1e 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -28,7 +28,6 @@
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import dtype
-from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt._Input import Input
 from torch_tensorrt._utils import is_tensorrt_version_supported
 from torch_tensorrt.dynamo import _defaults
@@ -270,33 +269,6 @@ def get_model_device(module: torch.fx.GraphModule) -> torch.device:
     return device
 
 
-def set_log_level(parent_logger: Any, level: Any) -> None:
-    """
-    Sets the log level to the user provided level.
-    This is used to set debug logging at a global level
-    at entry points of tracing, dynamo and torch_compile compilation.
-    And set log level for c++ torch trt logger if runtime is available.
-    """
-    if parent_logger:
-        parent_logger.setLevel(level)
-
-    if ENABLED_FEATURES.torch_tensorrt_runtime:
-        if level == logging.DEBUG:
-            log_level = trt.ILogger.Severity.VERBOSE
-        elif level == logging.INFO:
-            log_level = trt.ILogger.Severity.INFO
-        elif level == logging.WARNING:
-            log_level = trt.ILogger.Severity.WARNING
-        elif level == logging.ERROR:
-            log_level = trt.ILogger.Severity.ERROR
-        elif level == logging.CRITICAL:
-            log_level = trt.ILogger.Severity.INTERNAL_ERROR
-        else:
-            raise AssertionError(f"{level} is not valid log level")
-
-        torch.ops.tensorrt.set_logging_level(int(log_level))
-
-
 def prepare_inputs(
     inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any],
     disable_memory_format_check: bool = False,
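Reviewer note: since `set_log_level` is removed above, a migration sketch for hypothetical external callers; the replacement API is the `torch_tensorrt.logging.set_level` function added later in this diff.

```python
# Migration sketch for callers of the removed utils.set_log_level
import logging

import torch_tensorrt.logging as torchtrt_logging

# Before (removed): set_log_level(logger.parent, logging.DEBUG)
# After: one call updates the Python logger and, when available, the C++ loggers
torchtrt_logging.set_level(logging.DEBUG)
```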
- """ - if parent_logger: - parent_logger.setLevel(level) - - if ENABLED_FEATURES.torch_tensorrt_runtime: - if level == logging.DEBUG: - log_level = trt.ILogger.Severity.VERBOSE - elif level == logging.INFO: - log_level = trt.ILogger.Severity.INFO - elif level == logging.WARNING: - log_level = trt.ILogger.Severity.WARNING - elif level == logging.ERROR: - log_level = trt.ILogger.Severity.ERROR - elif level == logging.CRITICAL: - log_level = trt.ILogger.Severity.INTERNAL_ERROR - else: - raise AssertionError(f"{level} is not valid log level") - - torch.ops.tensorrt.set_logging_level(int(log_level)) - - def prepare_inputs( inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any], disable_memory_format_check: bool = False, diff --git a/py/torch_tensorrt/logging.py b/py/torch_tensorrt/logging.py index 0cba3bd510..197a2e823b 100644 --- a/py/torch_tensorrt/logging.py +++ b/py/torch_tensorrt/logging.py @@ -3,7 +3,10 @@ import tensorrt as trt import torch -from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt._features import ( + has_torch_tensorrt_runtime, + has_torchscript_frontend, +) logging.captureWarnings(True) _LOGGER = logging.getLogger("torch_tensorrt [TensorRT Conversion Context]") @@ -31,6 +34,81 @@ def log(self, severity: trt.ILogger.Severity, msg: str) -> None: TRT_LOGGER = _TRTLogger() +def set_level(level: int, logger: Any = None) -> None: + """Set log level for both Python and C++ torch_tensorrt loggers. + + Permanently sets the log level until changed again or process exits. + Automatically handles runtime availability checks. + + This sets the log level for: + - Specified Python logger (or root torch_tensorrt logger if None) + - TorchScript frontend C++ logger (if available) + - Dynamo runtime C++ logger (if available) + + Args: + level: Python logging level (logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR, logging.CRITICAL) + logger: Optional logger to set level for. If None, sets the root torch_tensorrt logger. + + Example: + + # Set debug logging for entire session + torch_tensorrt.logging.set_level(logging.DEBUG) + + # Or set for a specific logger + my_logger = logging.getLogger("torch_tensorrt.dynamo") + torch_tensorrt.logging.set_level(logging.DEBUG, logger=my_logger) + """ + # Set the specified logger or default to root torch_tensorrt logger + if logger is None: + logging.getLogger("torch_tensorrt").setLevel(level) + _LOGGER.setLevel(level) + else: + logger.setLevel(level) + + if has_torchscript_frontend(): + from torch_tensorrt.ts import logging as ts_logging + + if level == logging.CRITICAL: + ts_logging.set_reportable_log_level(ts_logging.Level.InternalError) + elif level == logging.ERROR: + ts_logging.set_reportable_log_level(ts_logging.Level.Error) + elif level == logging.WARNING: + ts_logging.set_reportable_log_level(ts_logging.Level.Warning) + elif level == logging.INFO: + ts_logging.set_reportable_log_level(ts_logging.Level.Info) + elif level == logging.DEBUG: + ts_logging.set_reportable_log_level(ts_logging.Level.Debug) + elif level == logging.NOTSET: + ts_logging.set_reportable_log_level(ts_logging.Level.Graph) + else: + raise ValueError( + f"Invalid log level: {level}. 
Must be one of: " + f"logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR, logging.CRITICAL, logging.NOTSET" + ) + + elif has_torch_tensorrt_runtime(): + if level == logging.CRITICAL: + torch.ops.tensorrt.set_logging_level( + int(trt.ILogger.Severity.INTERNAL_ERROR) + ) + elif level == logging.ERROR: + torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.ERROR)) + elif level == logging.WARNING: + torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.WARNING)) + elif level == logging.INFO: + torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.INFO)) + elif level == logging.DEBUG: + torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.VERBOSE)) + elif level == logging.NOTSET: + # Graph level (most verbose) + torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.VERBOSE) + 1) + else: + raise ValueError( + f"Invalid log level: {level}. Must be one of: " + f"logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR, logging.CRITICAL, logging.NOTSET" + ) + + class internal_errors: """Context-manager to limit displayed log messages to just internal errors @@ -46,13 +124,13 @@ def __enter__(self) -> None: self.external_lvl = _LOGGER.getEffectiveLevel() _LOGGER.setLevel(logging.CRITICAL) - if ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging self.ts_level = ts_logging.get_reportable_log_level() ts_logging.set_reportable_log_level(ts_logging.Level.InternalError) - elif ENABLED_FEATURES.torch_tensorrt_runtime: + elif has_torch_tensorrt_runtime(): self.rt_level = torch.ops.tensorrt.get_logging_level() torch.ops.tensorrt.set_logging_level( int(trt.ILogger.Severity.INTERNAL_ERROR) @@ -61,12 +139,12 @@ def __enter__(self) -> None: def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: _LOGGER.setLevel(self.external_lvl) - if ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging ts_logging.set_reportable_log_level(self.ts_level) - elif ENABLED_FEATURES.torch_tensorrt_runtime: + elif has_torch_tensorrt_runtime(): torch.ops.tensorrt.set_logging_level(self.rt_level) @@ -85,25 +163,25 @@ def __enter__(self) -> None: self.external_lvl = _LOGGER.getEffectiveLevel() _LOGGER.setLevel(logging.ERROR) - if ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging self.ts_level = ts_logging.get_reportable_log_level() ts_logging.set_reportable_log_level(ts_logging.Level.Error) - elif ENABLED_FEATURES.torch_tensorrt_runtime: + elif has_torch_tensorrt_runtime(): self.rt_level = torch.ops.tensorrt.get_logging_level() torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.ERROR)) def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: _LOGGER.setLevel(self.external_lvl) - if ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging ts_logging.set_reportable_log_level(self.ts_level) - elif ENABLED_FEATURES.torch_tensorrt_runtime: + elif has_torch_tensorrt_runtime(): torch.ops.tensorrt.set_logging_level(self.rt_level) @@ -122,25 +200,25 @@ def __enter__(self) -> None: self.external_lvl = _LOGGER.getEffectiveLevel() _LOGGER.setLevel(logging.WARNING) - if ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging self.ts_level = ts_logging.get_reportable_log_level() 
ts_logging.set_reportable_log_level(ts_logging.Level.Warning) - elif ENABLED_FEATURES.torch_tensorrt_runtime: + elif has_torch_tensorrt_runtime(): self.rt_level = torch.ops.tensorrt.get_logging_level() torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.WARNING)) def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: _LOGGER.setLevel(self.external_lvl) - if ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging ts_logging.set_reportable_log_level(self.ts_level) - elif ENABLED_FEATURES.torch_tensorrt_runtime: + elif has_torch_tensorrt_runtime(): torch.ops.tensorrt.set_logging_level(self.rt_level) @@ -159,25 +237,25 @@ def __enter__(self) -> None: self.external_lvl = _LOGGER.getEffectiveLevel() _LOGGER.setLevel(logging.INFO) - if ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging self.ts_level = ts_logging.get_reportable_log_level() ts_logging.set_reportable_log_level(ts_logging.Level.Info) - elif ENABLED_FEATURES.torch_tensorrt_runtime: + elif has_torch_tensorrt_runtime(): self.rt_level = torch.ops.tensorrt.get_logging_level() torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.INFO)) def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: _LOGGER.setLevel(self.external_lvl) - if ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging ts_logging.set_reportable_log_level(self.ts_level) - elif ENABLED_FEATURES.torch_tensorrt_runtime: + elif has_torch_tensorrt_runtime(): torch.ops.tensorrt.set_logging_level(self.rt_level) @@ -196,25 +274,25 @@ def __enter__(self) -> None: self.external_lvl = _LOGGER.getEffectiveLevel() _LOGGER.setLevel(logging.DEBUG) - if ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging self.ts_level = ts_logging.get_reportable_log_level() ts_logging.set_reportable_log_level(ts_logging.Level.Debug) - elif ENABLED_FEATURES.torch_tensorrt_runtime: + elif has_torch_tensorrt_runtime(): self.rt_level = torch.ops.tensorrt.get_logging_level() torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.VERBOSE)) def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: _LOGGER.setLevel(self.external_lvl) - if ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging ts_logging.set_reportable_log_level(self.ts_level) - elif ENABLED_FEATURES.torch_tensorrt_runtime: + elif has_torch_tensorrt_runtime(): torch.ops.tensorrt.set_logging_level(self.rt_level) @@ -234,23 +312,23 @@ def __enter__(self) -> None: self.external_lvl = _LOGGER.getEffectiveLevel() _LOGGER.setLevel(logging.NOTSET) - if ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging self.ts_level = ts_logging.get_reportable_log_level() ts_logging.set_reportable_log_level(ts_logging.Level.Graph) - elif ENABLED_FEATURES.torch_tensorrt_runtime: + elif has_torch_tensorrt_runtime(): self.rt_level = torch.ops.tensorrt.get_logging_level() torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.VERBOSE) + 1) def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: _LOGGER.setLevel(self.external_lvl) - if ENABLED_FEATURES.torchscript_frontend: + if has_torchscript_frontend(): from torch_tensorrt.ts import logging as ts_logging 
diff --git a/tests/py/dynamo/distributed/test_nccl_ops.py b/tests/py/dynamo/distributed/test_nccl_ops.py
index 91476fdc63..27f6aa624f 100644
--- a/tests/py/dynamo/distributed/test_nccl_ops.py
+++ b/tests/py/dynamo/distributed/test_nccl_ops.py
@@ -26,7 +26,26 @@
 )
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt._utils import is_platform_supported_for_trtllm
+from torch_tensorrt._features import ENABLED_FEATURES
+
+
+def is_distributed_nccl_available():
+    """
+    Check if torch.distributed with the NCCL backend is available.
+
+    Note: torch.distributed is available on Windows, but the NCCL backend is not.
+    NCCL (NVIDIA Collective Communications Library) is Linux/Unix only.
+    This function returns False on Windows, Jetson, and other platforms
+    where the NCCL backend is not supported.
+    """
+    try:
+        import torch.distributed as dist
+
+        # False on Windows (where Gloo is the backend); on Jetson Orin it is
+        # available only in some torch distributions.
+        return dist.is_nccl_available()
+    except (ImportError, AttributeError):
+        return False
+
 
 if "OMPI_COMM_WORLD_SIZE" in os.environ:
     set_environment_variables_pytest_multi_process()
@@ -71,9 +90,15 @@ def forward(self, x):
 
 
 class TestNcclOpsConverter(DispatchTestCase):
+    # 1. Skip if the NCCL backend is not available (e.g., Windows, Jetson) - a hard requirement
+    # 2. Skip if TRT-LLM is unavailable (e.g., CUDA 13) - no converters are registered
+    @unittest.skipIf(
+        not is_distributed_nccl_available(),
+        "Skipped: NCCL backend is not available (not supported on Windows/Jetson Orin).",
+    )
     @unittest.skipIf(
-        not is_platform_supported_for_trtllm(),
-        "Skipped on Windows, Jetson and CUDA13: NCCL backend is not supported.",
+        not ENABLED_FEATURES.trtllm_for_nccl,
+        "Skipped: TensorRT-LLM plugin for NCCL is not available (e.g., CUDA 13).",
     )
     @classmethod
     def setUpClass(cls):
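Reviewer note: an end-to-end sketch of the new `initialize_distributed_env` return signature, for an example script launched with something like `mpirun -n 2 python example.py`. The module path and script structure are assumptions, not part of the diff.

```python
# Hypothetical example script using the new 4-tuple return value
from tensor_parallel_initialize_dist import (
    cleanup_distributed_env,
    initialize_distributed_env,
)

device_mesh, world_size, rank, logger = initialize_distributed_env(
    logger_file_name="tensor_parallel",  # per-rank file: tensor_parallel_<rank>.log
    file_level="debug",
    console_level="info",
)
try:
    logger.info(f"rank {rank}/{world_size} initialized on mesh {device_mesh}")
    # ... build and run the tensor-parallel model here ...
finally:
    cleanup_distributed_env()
```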