Skip to content

Commit 31666e3

Browse files
committed
Adding rank based logging for torch distributed examples. Also correcting TRT-LLM installation fallback cases
1 parent 13f9f0b commit 31666e3

File tree

4 files changed

+174
-61
lines changed

4 files changed

+174
-61
lines changed

examples/distributed_inference/tensor_parallel_initialize_dist.py

Lines changed: 90 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Tensor Parallel Initialize Distributed Environment
44
==================================================
55
6-
This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference.
6+
This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference. These utilities are useful for tensor parallel distributed inference examples using torch.distributed.
77
"""
88

99
import logging
@@ -16,30 +16,66 @@
1616
import torch.distributed as dist
1717
from torch.distributed._tensor.device_mesh import init_device_mesh
1818

19+
logger = logging.getLogger(__name__)
1920

20-
def find_repo_root(max_depth=10):
21-
dir_path = os.path.dirname(os.path.realpath(__file__))
22-
for i in range(max_depth):
23-
files = os.listdir(dir_path)
24-
if "MODULE.bazel" in files:
25-
return dir_path
26-
else:
27-
dir_path = os.path.dirname(dir_path)
2821

29-
raise RuntimeError("Could not find repo root")
22+
def initialize_logger(
23+
rank, logger_file_name, file_level=logging.DEBUG, console_level=logging.INFO
24+
):
25+
"""Initialize rank-specific Torch-TensorRT logger with configurable handler levels.
3026
27+
Logger level is set to DEBUG (pass-through), handlers control filtering for files and stream buffers
3128
32-
def initialize_logger(rank, logger_file_name):
33-
logger = logging.getLogger()
34-
logger.setLevel(logging.INFO)
29+
Args:
30+
rank: Process rank for multi-GPU
31+
logger_file_name: Base name for log file (will add _rank.log)
32+
file_level: What goes to file - default DEBUG (everything)
33+
console_level: What prints to console - default INFO (clean output)
34+
"""
35+
logger = logging.getLogger("torch_tensorrt")
36+
logger.setLevel(logging.DEBUG)
37+
logger.handlers.clear()
38+
39+
# File handler
3540
fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
36-
fh.setLevel(logging.INFO)
41+
fh.setLevel(file_level)
42+
fh.setFormatter(
43+
logging.Formatter(
44+
f"[Rank {rank}] %(asctime)s - %(name)s - %(levelname)s - %(message)s"
45+
)
46+
)
3747
logger.addHandler(fh)
48+
49+
# console handler
50+
ch = logging.StreamHandler()
51+
ch.setLevel(console_level) # Console handler controls what's printed
52+
ch.setFormatter(logging.Formatter(f"[Rank {rank}] %(levelname)s: %(message)s"))
53+
logger.addHandler(ch)
54+
55+
# safegauard though not reqd
56+
logger.propagate = False
3857
return logger
3958

4059

4160
# This is required for env initialization since we use mpirun
42-
def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
61+
def initialize_distributed_env(
62+
logger_file_name,
63+
rank=0,
64+
world_size=1,
65+
port=29500,
66+
file_level="debug",
67+
console_level="info",
68+
):
69+
"""Initialize distributed environment with handler-based logging.
70+
71+
Args:
72+
logger_file_name: Base name for log files
73+
rank: Initial rank (overridden by OMPI env vars)
74+
world_size: Initial world size (overridden by OMPI env vars)
75+
port: Master port for distributed communication
76+
file_level: File handler level - "debug", "info", "warning" (default: "debug")
77+
console_level: Console handler level - "debug", "info", "warning" (default: "info")
78+
"""
4379
local_rank = int(
4480
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
4581
)
@@ -50,9 +86,6 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950
5086
os.environ["WORLD_SIZE"] = str(world_size)
5187
os.environ["MASTER_ADDR"] = "127.0.0.1"
5288
os.environ["MASTER_PORT"] = str(port)
53-
os.environ["TRTLLM_PLUGINS_PATH"] = (
54-
find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so"
55-
)
5689

5790
# Necessary to assign a device to each rank.
5891
torch.cuda.set_device(local_rank)
@@ -66,12 +99,50 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950
6699
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
67100
rank = device_mesh.get_rank()
68101
assert rank == local_rank
69-
logger = initialize_logger(rank, logger_file_name)
102+
103+
# Convert string handler levels to logging constants
104+
level_map = {
105+
"debug": logging.DEBUG,
106+
"info": logging.INFO,
107+
"warning": logging.WARNING,
108+
"error": logging.ERROR,
109+
}
110+
file_level_int = level_map.get(file_level.lower(), logging.DEBUG)
111+
console_level_int = level_map.get(console_level.lower(), logging.INFO)
112+
113+
# Initialize logger with handler-specific levels
114+
# Logger itself is always DEBUG - handlers do the filtering
115+
logger = initialize_logger(
116+
rank,
117+
logger_file_name,
118+
file_level=file_level_int,
119+
console_level=console_level_int,
120+
)
121+
70122
device_id = (
71123
rank % torch.cuda.device_count()
72124
) # Ensure each rank gets a unique device
73125
torch.cuda.set_device(device_id)
74126

127+
# Set C++ TensorRT runtime log level based on most verbose handler
128+
# this is similar to set_log_level()
129+
cpp_level = min(file_level_int, console_level_int)
130+
try:
131+
import tensorrt as trt
132+
from torch_tensorrt._features import ENABLED_FEATURES
133+
134+
if ENABLED_FEATURES.torch_tensorrt_runtime:
135+
if cpp_level == logging.DEBUG:
136+
torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.VERBOSE))
137+
elif cpp_level == logging.INFO:
138+
torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.INFO))
139+
elif cpp_level == logging.WARNING:
140+
torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.WARNING))
141+
else:
142+
torch.ops.tensorrt.set_logging_level(int(trt.ILogger.Severity.ERROR))
143+
except Exception as e:
144+
logger.warning(f"Could not set C++ TensorRT log level: {e}")
145+
75146
return device_mesh, world_size, rank, logger
76147

77148

py/torch_tensorrt/_features.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474

7575
def _enabled_features_str() -> str:
7676
enabled = lambda x: "ENABLED" if x else "DISABLED"
77-
out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n - Refit: {enabled(_REFIT_AVAIL)}\n - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)} \n - TensorRT-RTX: {enabled(_TENSORRT_RTX)}\n" # type: ignore[no-untyped-call]
77+
out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n - Refit: {enabled(_REFIT_AVAIL)}\n - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)} \n - TensorRT-RTX: {enabled(_TENSORRT_RTX)}\n - TensorRT-LLM for NCCL: {enabled(_TRTLLM_AVAIL)}\n" # type: ignore[no-untyped-call]
7878
return out_str
7979

8080

@@ -163,17 +163,24 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
163163

164164

165165
def needs_trtllm_for_nccl(f: Callable[..., Any]) -> Callable[..., Any]:
166+
"""
167+
Runtime check decorator for TensorRT-LLM NCCL plugin availability.
168+
169+
WARNING: This decorator CANNOT prevent registration of converters at import time.
170+
When used with @dynamo_tensorrt_converter, the converter is always registered
171+
regardless of decorator order, because registration happens at import time before
172+
the wrapper is called.
173+
174+
This decorator is kept for potential non-registration use cases where
175+
runtime checks are appropriate.
176+
@apbose: to discuss if this is required
177+
"""
178+
166179
def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
167180
if ENABLED_FEATURES.trtllm_for_nccl:
168181
return f(*args, **kwargs)
169182
else:
170-
171-
def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
172-
raise NotImplementedError(
173-
"Refit feature is currently not available in Python 3.13 or higher"
174-
)
175-
176-
return not_implemented(*args, **kwargs)
183+
raise NotImplementedError("TensorRT-LLM plugin for NCCL is not available")
177184

178185
return wrapper
179186

py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py

Lines changed: 47 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import tensorrt as trt
77
from torch.fx.node import Argument, Target
8-
from torch_tensorrt._features import needs_trtllm_for_nccl
8+
from torch_tensorrt._features import ENABLED_FEATURES
99
from torch_tensorrt.dynamo._SourceIR import SourceIR
1010
from torch_tensorrt.dynamo.conversion import impl
1111
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -20,37 +20,53 @@
2020
_LOGGER: logging.Logger = logging.getLogger(__name__)
2121

2222

23-
@needs_trtllm_for_nccl
24-
@dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
25-
def fused_nccl_gather(
26-
ctx: ConversionContext,
27-
target: Target,
28-
args: Tuple[Argument, ...],
29-
kwargs: Dict[str, Argument],
30-
name: str,
31-
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
32-
return impl.nccl_ops.nccl_gather(
33-
ctx,
34-
target,
35-
SourceIR.ATEN,
36-
name,
37-
[args[0]],
23+
# Conditionally register NCCL converters only if TensorRT-LLM plugin is available.
24+
# We use an `if` statement instead of @needs_trtllm_for_nccl decorator because
25+
# @dynamo_tensorrt_converter ALWAYS registers at import time regardless of decorator
26+
# order. Conditional registration prevents registration when TRTLLM is unavailable,
27+
# allowing fallback to PyTorch execution for NCCL ops.
28+
29+
# Order 1: @needs_trtllm_for_nccl followed by registering the converter leads to plugin registry not finding nccl ops plugins since we register the bare converter, without the decorator
30+
# Order 2: registering the converter first followed by @needs_trtllm_for_nccl leads to "NotImplementedError: TensorRT-LLM plugin for NCCL is not available :TensorRT-LLM plugin for NCCL is not available" and no fall back to pytorch
31+
if ENABLED_FEATURES.trtllm_for_nccl:
32+
_LOGGER.debug(
33+
"TensorRT-LLM plugin for NCCL is available. Registering NCCL converters."
3834
)
3935

36+
@dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
37+
def fused_nccl_gather(
38+
ctx: ConversionContext,
39+
target: Target,
40+
args: Tuple[Argument, ...],
41+
kwargs: Dict[str, Argument],
42+
name: str,
43+
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
44+
return impl.nccl_ops.nccl_gather(
45+
ctx,
46+
target,
47+
SourceIR.ATEN,
48+
name,
49+
[args[0]],
50+
)
51+
52+
@dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op)
53+
def fused_nccl_reduce_scatter(
54+
ctx: ConversionContext,
55+
target: Target,
56+
args: Tuple[Argument, ...],
57+
kwargs: Dict[str, Argument],
58+
name: str,
59+
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
60+
return impl.nccl_ops.nccl_reduce_scatter(
61+
ctx,
62+
target,
63+
SourceIR.ATEN,
64+
name,
65+
[args[0]],
66+
)
4067

41-
@needs_trtllm_for_nccl
42-
@dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op)
43-
def fused_nccl_reduce_scatter(
44-
ctx: ConversionContext,
45-
target: Target,
46-
args: Tuple[Argument, ...],
47-
kwargs: Dict[str, Argument],
48-
name: str,
49-
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
50-
return impl.nccl_ops.nccl_reduce_scatter(
51-
ctx,
52-
target,
53-
SourceIR.ATEN,
54-
name,
55-
[args[0]],
68+
else:
69+
_LOGGER.info(
70+
"TensorRT-LLM plugin for NCCL is not available. "
71+
"NCCL operations will fall back to PyTorch execution."
5672
)

tests/py/dynamo/distributed/test_nccl_ops.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,24 @@
88
from distributed_utils import set_environment_variables_pytest
99
from parameterized import parameterized
1010
from torch.testing._internal.common_utils import run_tests
11-
from torch_tensorrt._utils import is_platform_supported_for_trtllm
11+
12+
13+
def is_distributed_nccl_available():
14+
"""
15+
Check if torch.distributed with NCCL backend is available.
16+
17+
Note: torch.distributed is available on Windows but NCCL backend is not.
18+
NCCL (NVIDIA Collective Communications Library) is Linux/Unix only.
19+
This function returns False on Windows, Jetson, and other platforms
20+
where NCCL backend is not supported.
21+
"""
22+
try:
23+
import torch.distributed as dist
24+
25+
# Check if NCCL backend is available (False on Windows, since its gloo. For ORIN some torch distribution it is available
26+
return dist.is_nccl_available()
27+
except (ImportError, AttributeError):
28+
return False
1229

1330

1431
class DistributedGatherModel(nn.Module):
@@ -42,9 +59,11 @@ def forward(self, x):
4259

4360

4461
class TestNcclOpsConverter(DispatchTestCase):
62+
# 1. Skip if NCCL backend is not available (e.g., Windows, Jetson) - hard requirement
63+
# 2. Don't skip if TRTLLM is unavailable (e.g., CUDA 13) - falls back to PyTorch
4564
@unittest.skipIf(
46-
not is_platform_supported_for_trtllm(),
47-
"Skipped on Windows, Jetson and CUDA13: NCCL backend is not supported.",
65+
not is_distributed_nccl_available(),
66+
"Skipped: NCCL backend is not available (Windows/Jetson not supported).",
4867
)
4968
@classmethod
5069
def setUpClass(cls):

0 commit comments

Comments
 (0)