Commit 0f03fb1

example: using nvrtc kernel for aot plugin
1 parent a80572d commit 0f03fb1

1 file changed: 219 additions, 0 deletions
"""
Minimal reproducible example demonstrating a TensorRT fp16 custom_op() issue.

This module shows the bug where torch_tensorrt.dynamo.conversion.plugins.custom_op()
fails to compile operations that use fp16 (half-precision) tensors.

The issue occurs because the JIT plugin generator doesn't properly declare format
support for fp16 data types in the generated TensorRT plugin.
"""

from typing import List, Tuple, Union

import torch

import torch_tensorrt

# CUDA kernel source (NVRTC) used by the torch custom op
# Note: TensorRT passes args as: inputs, extra_args, outputs

cu_code = """
// Simple pointwise Sigmoid kernel: f(x) = 1 / (1 + exp(-x))
extern "C" __global__ void pointwise_sigmoid_kernel_nvrtc(const float* __restrict__ input,
                                                          const int size,
                                                          float* __restrict__ output) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx < size) {
        const float x = input[idx];
        // use fast device intrinsic to avoid headers
        output[idx] = 1.0f / (1.0f + __expf(-x));
    }
}
"""
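
# If half precision were to be supported, a hypothetical fp16 variant of the kernel
# (illustration only; __half, __half2float, and __float2half come from cuda_fp16.h,
# which NVRTC can find via the include_path configured below) might look like:
#
#     cu_code_fp16 = """
#     #include <cuda_fp16.h>
#     extern "C" __global__ void pointwise_sigmoid_kernel_nvrtc_fp16(const __half* __restrict__ input,
#                                                                    const int size,
#                                                                    __half* __restrict__ output) {
#         const int idx = blockIdx.x * blockDim.x + threadIdx.x;
#         if (idx < size) {
#             const float x = __half2float(input[idx]);
#             output[idx] = __float2half(1.0f / (1.0f + __expf(-x)));
#         }
#     }
#     """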

# Prepare NVRTC program, kernel, and stream once (simple eager path)
from cuda.core.experimental import Device as _CudaDevice
from cuda.core.experimental import LaunchConfig as _LaunchConfig
from cuda.core.experimental import Program as _CudaProgram
from cuda.core.experimental import ProgramOptions as _CudaProgramOptions
from cuda.core.experimental import launch as _cuda_launch

_cuda_device = _CudaDevice()
_cuda_device.set_current()
_cuda_stream = _cuda_device.create_stream()  # not used below; the op wraps PyTorch's current stream
_program_options = _CudaProgramOptions(
    std="c++17",
    arch=f"sm_{_cuda_device.arch}",
    include_path=["/usr/local/cuda/include"],
)
_program = _CudaProgram(cu_code, code_type="c++", options=_program_options)
_module = _program.compile("ptx", name_expressions=("pointwise_sigmoid_kernel_nvrtc",))
_kernel = _module.get_kernel("pointwise_sigmoid_kernel_nvrtc")

# Eager torch custom_op implemented using the CUDA kernel above (no Triton)


# ============================================================================
# Custom Op Registration
# ============================================================================


@torch.library.custom_op("pointwise_sigmoid_ops::pointwise_sigmoid", mutates_args=())  # type: ignore[misc]
def pointwise_sigmoid(X: torch.Tensor) -> torch.Tensor:
    assert X.is_cuda, "Tensor must be on CUDA device."
    assert X.dtype == torch.float32, "For this test, expected float32 input"

    Y = torch.empty_like(X)
    N = int(X.numel())

    block = 256

    grid_x = max(1, (N + block - 1) // block)
    config = _LaunchConfig(grid=grid_x, block=block)

    # Use PyTorch's current stream by wrapping it for cuda.core
    # (cuda.core accepts any object exposing the __cuda_stream__ protocol)
    class _PyTorchStreamWrapper:
        def __init__(self, pt_stream):
            self.pt_stream = pt_stream

        def __cuda_stream__(self):
            stream_id = self.pt_stream.cuda_stream
            return (0, stream_id)

    pt_stream = torch.cuda.current_stream()
    s = _cuda_device.create_stream(_PyTorchStreamWrapper(pt_stream))

    # Launch kernel with raw pointers as in the cuda.core example
    # Note: argument order is input, size, output (matching TensorRT's convention)
    _cuda_launch(
        s,
        config,
        _kernel,
        X.data_ptr(),
        N,
        Y.data_ptr(),
    )

    return Y


@torch.library.register_fake("pointwise_sigmoid_ops::pointwise_sigmoid")
def _(input: torch.Tensor) -> torch.Tensor:
    """Fake implementation for TorchDynamo tracing of the base operation."""
    return torch.empty_like(input)


# ============================================================================
# TensorRT Wrapper with custom_op() - THIS FAILS WITH FP16
# ============================================================================

import tensorrt.plugin as trtp


@trtp.register("pointwise_sigmoid_ops::pointwise_sigmoid")
def sigmoid_plugin_desc(input: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
    return (input.like(),)


@trtp.autotune("pointwise_sigmoid_ops::pointwise_sigmoid")
def sigmoid_autotune(
    input: trtp.TensorDesc,
    outputs: Tuple[trtp.TensorDesc],
) -> List[trtp.AutoTuneCombination]:
    # Match the float32 path; add an FP16 combination if you want both
    # (see the sketch below)
    return [trtp.AutoTuneCombination("FP32, FP32", "LINEAR")]

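
# A sketch of what declaring both precisions might look like (an assumption about
# trtp.AutoTuneCombination usage, not verified against this TensorRT version; it
# would also require an fp16 kernel and an eager path that accepts half inputs):
#
#     return [
#         trtp.AutoTuneCombination("FP32, FP32", "LINEAR"),
#         trtp.AutoTuneCombination("FP16, FP16", "LINEAR"),
#     ]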

@trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid")
def sigmoid_aot_nvrtc_impl(
    input: trtp.TensorDesc,
    outputs: Tuple[trtp.TensorDesc],
    tactic: int,
) -> Tuple[
    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
]:
    # Reuse the PTX compiled above for the AOT plugin
    compiled_kernel = _module.code.decode("utf-8")
    print(type(compiled_kernel))
    print(compiled_kernel)

    N = input.shape_expr.numel()
    launch_params = trtp.KernelLaunchParams()
    block = 256
    launch_params.grid_x = trtp.cdiv(N, block)
    launch_params.block_x = block
    launch_params.shared_mem = 0

    # Extra scalar kernel argument: the element count, passed between inputs and outputs
    extra_args = trtp.SymIntExprs(1)
    extra_args[0] = trtp.SymInt32(N)

    return (
        "pointwise_sigmoid_kernel_nvrtc",
        compiled_kernel,
        launch_params,
        extra_args,
    )


torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "pointwise_sigmoid_ops::pointwise_sigmoid",
    supports_dynamic_shapes=True,
    requires_output_allocator=False,
)

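# For reference, the single-call path named in the module docstring would look
# roughly like the commented sketch below (signature assumed, not exercised here).
# It JIT-generates both the trtp plugin and the converter, replacing the manual
# @trtp.register/@trtp.autotune/@trtp.aot_impl blocks above, and is the path that
# reportedly lacks fp16 format support.
#
#     torch_tensorrt.dynamo.conversion.plugins.custom_op(
#         "pointwise_sigmoid_ops::pointwise_sigmoid",
#         supports_dynamic_shapes=True,
#     )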

# ============================================================================
# Test Model
# ============================================================================


class PointwiseSigmoidModel_WithTRTWrapper(torch.nn.Module):
    """
    Test model that uses the TRT wrapper with custom_op() registration.

    When compiled with torch_tensorrt.compile() using fp16 inputs, this will
    fail with: "could not find any supported formats consistent with input/output
    data types"
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        z = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(input)
        return z


if __name__ == "__main__":
    model = PointwiseSigmoidModel_WithTRTWrapper().to("cuda").eval()
    input = torch.randn(1, 1024, device="cuda", dtype=torch.float32)

    print(torch.sigmoid(input))

    print(model(input))

    with torch_tensorrt.logging.debug():
        trt_inputs = [input]
        model_trt = torch_tensorrt.compile(
            model,
            inputs=trt_inputs,
            enabled_precisions={torch.float32},
            min_block_size=1,
        )
    print("Model compiled successfully!")
    print("Running inference with compiled model...")
    print("Compiled model output:")
    print(model_trt(input))
    print("Original model output:")
    print(model(input))
    with torch.no_grad():
        for i in range(10):
            res = model_trt(input)
            assert torch.allclose(
                res, model(input), rtol=1e-2, atol=1e-2
            ), "Results do not match!"

    # print("Inference successful!")
