examples/dynamo/nvrtc_aot_plugin.py (219 additions, 0 deletions)

@@ -0,0 +1,219 @@
"""
Minimal reproducible example demonstrating TensorRT fp16 custom_op() issue.

This module shows the bug where torch_tensorrt.dynamo.conversion.plugins.custom_op()
fails to compile operations that use fp16 (half-precision) tensors.

The issue occurs because the JIT plugin generator doesn't properly declare format
support for fp16 data types in the generated TensorRT plugin.
"""

from typing import List, Tuple, Union

import torch

import torch_tensorrt

# CUDA kernel source (NVRTC) used by the torch custom op
# Note: TensorRT passes args as: inputs, extra_args, outputs

cu_code = """
// Simple pointwise Sigmoid kernel: f(x) = 1 / (1 + exp(-x))
extern "C" __global__ void pointwise_sigmoid_kernel_nvrtc(const float* __restrict__ input,
const int size,
float* __restrict__ output) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;

if (idx < size) {
const float x = input[idx];
// use fast device intrinsic to avoid headers
output[idx] = 1.0f / (1.0f + __expf(-x));
}
}
"""

# Prepare NVRTC program, kernel, and stream once (simple eager path)
from cuda.core.experimental import Device as _CudaDevice
from cuda.core.experimental import LaunchConfig as _LaunchConfig
from cuda.core.experimental import Program as _CudaProgram
from cuda.core.experimental import ProgramOptions as _CudaProgramOptions
from cuda.core.experimental import launch as _cuda_launch

_cuda_device = _CudaDevice()
_cuda_device.set_current()
_cuda_stream = _cuda_device.create_stream()
_program_options = _CudaProgramOptions(
    std="c++17",
    arch=f"sm_{_cuda_device.arch}",
    include_path=["/usr/local/cuda/include"],
)
_program = _CudaProgram(cu_code, code_type="c++", options=_program_options)
_module = _program.compile("ptx", name_expressions=("pointwise_sigmoid_kernel_nvrtc",))
_kernel = _module.get_kernel("pointwise_sigmoid_kernel_nvrtc")

# Eager torch custom_op implemented using the CUDA kernel above (no Triton)


# ============================================================================
# Custom Op Registration
# ============================================================================


@torch.library.custom_op("pointwise_sigmoid_ops::pointwise_sigmoid", mutates_args=())  # type: ignore[misc]
def pointwise_sigmoid(X: torch.Tensor) -> torch.Tensor:
    assert X.is_cuda, "Tensor must be on a CUDA device."
    assert X.dtype == torch.float32, "For this test, a float32 input is expected."

    Y = torch.empty_like(X)
    N = int(X.numel())

    block = 256
    grid_x = max(1, (N + block - 1) // block)
    config = _LaunchConfig(grid=grid_x, block=block)

    # Use PyTorch's current stream by wrapping it for cuda.core
    class _PyTorchStreamWrapper:
        def __init__(self, pt_stream):
            self.pt_stream = pt_stream

        def __cuda_stream__(self):
            # cuda.core stream protocol: (protocol version, raw CUstream handle)
            stream_id = self.pt_stream.cuda_stream
            return (0, stream_id)

    pt_stream = torch.cuda.current_stream()
    s = _cuda_device.create_stream(_PyTorchStreamWrapper(pt_stream))

    # Launch the kernel with raw pointers, as in the cuda.core examples.
    # Note: the argument order is input, size, output (matching TensorRT's
    # inputs, extra_args, outputs convention).
    _cuda_launch(
        s,
        config,
        _kernel,
        X.data_ptr(),
        N,
        Y.data_ptr(),
    )

    return Y


@torch.library.register_fake("pointwise_sigmoid_ops::pointwise_sigmoid")
def _(input: torch.Tensor) -> torch.Tensor:
    """Fake implementation for TorchDynamo tracing of the base operation."""
    return torch.empty_like(input)


# ============================================================================
# TensorRT Wrapper with custom_op() - THIS FAILS WITH FP16
# ============================================================================

import tensorrt.plugin as trtp


@trtp.register("pointwise_sigmoid_ops::pointwise_sigmoid")
def sigmoid_plugin_desc(input: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
    # [Review comment from a Collaborator]: Can't we use the autogenerated
    # registration for this? (A hedged sketch of that path follows the
    # generate_plugin_converter call below.)

    return (input.like(),)


@trtp.autotune("pointwise_sigmoid_ops::pointwise_sigmoid")
def sigmoid_autotune(
    input: trtp.TensorDesc,
    outputs: Tuple[trtp.TensorDesc],
) -> List[trtp.AutoTuneCombination]:
    # Match the float32 eager path; an FP16 combination could be added to
    # support both (a hedged sketch follows below)
    return [trtp.AutoTuneCombination("FP32, FP32", "LINEAR")]
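# A hedged sketch, not active here: to advertise fp16 support as well, the
# autotune hook could return one combination per dtype pairing. This assumes
# the registered kernel actually handles half-precision data, which the
# fp32-only NVRTC kernel above does not:
#
#     return [
#         trtp.AutoTuneCombination("FP32, FP32", "LINEAR"),
#         trtp.AutoTuneCombination("FP16, FP16", "LINEAR"),
#     ]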


@trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid")
def sigmoid_aot_nvrtc_impl(
    input: trtp.TensorDesc,
    outputs: Tuple[trtp.TensorDesc],
    tactic: int,
) -> Tuple[
    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
]:

    # The PTX produced by NVRTC at module import time
    compiled_kernel = _module.code.decode("utf-8")

    N = input.shape_expr.numel()
    launch_params = trtp.KernelLaunchParams()
    block = 256
    launch_params.grid_x = trtp.cdiv(N, block)
    launch_params.block_x = block
    launch_params.shared_mem = 0

    # One extra scalar argument, the element count, passed between the
    # input and output pointers (inputs, extra_args, outputs)
    extra_args = trtp.SymIntExprs(1)
    extra_args[0] = trtp.SymInt32(N)

    return (
        "pointwise_sigmoid_kernel_nvrtc",
        compiled_kernel,
        launch_params,
        extra_args,
    )


torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "pointwise_sigmoid_ops::pointwise_sigmoid",
    supports_dynamic_shapes=True,
    requires_output_allocator=False,
)
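# Regarding the review question above: torch_tensorrt also exposes a one-call
# helper that autogenerates the trtp registration, autotune, and converter
# directly from the torch custom op. A hedged sketch of that path, which is
# the API the module docstring reports as failing for fp16:
#
#     torch_tensorrt.dynamo.conversion.plugins.custom_op(
#         "pointwise_sigmoid_ops::pointwise_sigmoid",
#         supports_dynamic_shapes=True,
#     )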


# ============================================================================
# Test Model
# ============================================================================


class PointwiseSigmoidModel_WithTRTWrapper(torch.nn.Module):
"""
Test model that uses the TRT wrapper with custom_op() registration.

When compiled with torch_tensorrt.compile() using fp16 inputs, this will
fail with: "could not find any supported formats consistent with input/output
data types"
"""

def forward(self, input: torch.Tensor) -> torch.Tensor:

z = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(input)
return z


if __name__ == "__main__":
    model = PointwiseSigmoidModel_WithTRTWrapper().to("cuda").eval()
    input = torch.randn(1, 1024, device="cuda", dtype=torch.float32)

    # Eager references: built-in sigmoid and the custom op
    print(torch.sigmoid(input))
    print(model(input))

    with torch_tensorrt.logging.debug():
        trt_inputs = [input]
        model_trt = torch_tensorrt.compile(
            model,
            inputs=trt_inputs,
            enabled_precisions={torch.float32},
            min_block_size=1,
        )
        print("Model compiled successfully!")
        print("Running inference with compiled model...")
        print("Compiled model output:")
        print(model_trt(input))
        print("Original model output:")
        print(model(input))
        with torch.no_grad():
            for i in range(10):
                res = model_trt(input)
                assert torch.allclose(
                    res, model(input), rtol=1e-2, atol=1e-2
                ), "Results do not match!"

# print("Inference successful!")