"""
Minimal reproducible example demonstrating TensorRT fp16 custom_op() issue.

This module shows the bug where torch_tensorrt.dynamo.conversion.plugins.custom_op()
fails to compile operations that use fp16 (half-precision) tensors.

The issue occurs because the JIT plugin generator doesn't properly declare format
support for fp16 data types in the generated TensorRT plugin.
"""

from typing import List, Tuple, Union

import torch

import torch_tensorrt

# CUDA kernel source (NVRTC) used by the torch custom op
# Note: TensorRT passes args as: inputs, extra_args, outputs

cu_code = """
// Simple pointwise Sigmoid kernel: f(x) = 1 / (1 + exp(-x))
extern "C" __global__ void pointwise_sigmoid_kernel_nvrtc(const float* __restrict__ input,
                                                          const int size,
                                                          float* __restrict__ output) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx < size) {
        const float x = input[idx];
        // use fast device intrinsic to avoid headers
        output[idx] = 1.0f / (1.0f + __expf(-x));
    }
}
"""

# Prepare NVRTC program, kernel, and stream once (simple eager path)
from cuda.core.experimental import Device as _CudaDevice
from cuda.core.experimental import LaunchConfig as _LaunchConfig
from cuda.core.experimental import Program as _CudaProgram
from cuda.core.experimental import ProgramOptions as _CudaProgramOptions
from cuda.core.experimental import launch as _cuda_launch

_cuda_device = _CudaDevice()
_cuda_device.set_current()
_cuda_stream = _cuda_device.create_stream()
_program_options = _CudaProgramOptions(
    std="c++17",
    arch=f"sm_{_cuda_device.arch}",
    include_path=["/usr/local/cuda/include"],
)
_program = _CudaProgram(cu_code, code_type="c++", options=_program_options)
_module = _program.compile("ptx", name_expressions=("pointwise_sigmoid_kernel_nvrtc",))
_kernel = _module.get_kernel("pointwise_sigmoid_kernel_nvrtc")

# Eager torch custom_op implemented using the CUDA kernel above (no Triton)


# ============================================================================
# Custom Op Registration
# ============================================================================

@torch.library.custom_op("pointwise_sigmoid_ops::pointwise_sigmoid", mutates_args=())  # type: ignore[misc]
def pointwise_sigmoid(X: torch.Tensor) -> torch.Tensor:
    assert X.is_cuda, "Tensor must be on CUDA device."
    assert X.dtype == torch.float32, "For this test, expected float32 input"

    Y = torch.empty_like(X)
    N = int(X.numel())

    block = 256
    grid_x = max(1, (N + block - 1) // block)
    config = _LaunchConfig(grid=grid_x, block=block)

    # Use PyTorch's current stream by wrapping it for cuda.core
    class _PyTorchStreamWrapper:
        def __init__(self, pt_stream):
            self.pt_stream = pt_stream

        def __cuda_stream__(self):
            # cuda.core stream protocol: (protocol version, raw stream handle)
            stream_id = self.pt_stream.cuda_stream
            return (0, stream_id)

    pt_stream = torch.cuda.current_stream()
    s = _cuda_device.create_stream(_PyTorchStreamWrapper(pt_stream))

    # Launch the kernel with raw pointers, as in the cuda.core interop example.
    # Note: argument order is input, size, output (matching TensorRT's
    # inputs, extra_args, outputs convention).
    _cuda_launch(
        s,
        config,
        _kernel,
        X.data_ptr(),
        N,
        Y.data_ptr(),
    )

    return Y


@torch.library.register_fake("pointwise_sigmoid_ops::pointwise_sigmoid")
def _(input: torch.Tensor) -> torch.Tensor:
    """Fake implementation for TorchDynamo tracing of the base operation."""
    return torch.empty_like(input)


# ============================================================================
# TensorRT Wrapper with custom_op() - THIS FAILS WITH FP16
# ============================================================================

import tensorrt.plugin as trtp

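# The plugin is described to TensorRT in three parts: trtp.register supplies the
# shape/dtype inference, trtp.autotune lists the supported type/format
# combinations, and trtp.aot_impl returns the compiled kernel plus launch
# parameters (a summary of the tensorrt.plugin API as it is used here).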
@trtp.register("pointwise_sigmoid_ops::pointwise_sigmoid")
def sigmoid_plugin_desc(input: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
    """Shape/dtype inference: the output descriptor mirrors the input."""
    return (input.like(),)

@trtp.autotune("pointwise_sigmoid_ops::pointwise_sigmoid")
def sigmoid_autotune(
    input: trtp.TensorDesc,
    outputs: Tuple[trtp.TensorDesc],
) -> List[trtp.AutoTuneCombination]:
    # Match the float32 path; FP16 could be added here as well (see note below)
    return [trtp.AutoTuneCombination("FP32, FP32", "LINEAR")]

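# Note (assumption, not exercised by this repro): to advertise fp16 as well, the
# autotune combination above could presumably be widened, e.g.
#
#     return [trtp.AutoTuneCombination("FP32|FP16, FP32|FP16", "LINEAR")]
#
# but the generated plugin must still declare matching format support for fp16,
# which is the part the module docstring describes as going wrong.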
@trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid")
def sigmoid_aot_nvrtc_impl(
    input: trtp.TensorDesc,
    outputs: Tuple[trtp.TensorDesc],
    tactic: int,
) -> Tuple[
    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
]:
    # Reuse the PTX that was compiled at import time for the eager path
    compiled_kernel = _module.code.decode("utf-8")

    N = input.shape_expr.numel()
    launch_params = trtp.KernelLaunchParams()
    block = 256
    launch_params.grid_x = trtp.cdiv(N, block)
    launch_params.block_x = block
    launch_params.shared_mem = 0

    # Extra (non-tensor) kernel argument: the element count, passed symbolically
    extra_args = trtp.SymIntExprs(1)
    extra_args[0] = trtp.SymInt32(N)

    return (
        "pointwise_sigmoid_kernel_nvrtc",
        compiled_kernel,
        launch_params,
        extra_args,
    )

torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "pointwise_sigmoid_ops::pointwise_sigmoid",
    supports_dynamic_shapes=True,
    requires_output_allocator=False,
)
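
# For reference, the single-call path named in the module docstring would be
# (signature assumed, not verified here):
#
#     torch_tensorrt.dynamo.conversion.plugins.custom_op(
#         "pointwise_sigmoid_ops::pointwise_sigmoid",
#         supports_dynamic_shapes=True,
#     )
#
# which would generate the plugin and the converter in one step instead of the
# manual trtp.register/autotune/aot_impl + generate_plugin_converter path above.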


# ============================================================================
# Test Model
# ============================================================================

class PointwiseSigmoidModel_WithTRTWrapper(torch.nn.Module):
    """
    Test model that uses the TRT wrapper with custom_op() registration.

    When compiled with torch_tensorrt.compile() using fp16 inputs, this will
    fail with: "could not find any supported formats consistent with input/output
    data types"
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        z = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(input)
        return z

if __name__ == "__main__":
    model = PointwiseSigmoidModel_WithTRTWrapper().to("cuda").eval()
    input = torch.randn(1, 1024, device="cuda", dtype=torch.float32)
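    # Note (assumption): as written, this script exercises the working float32
    # path; per the docstrings above, switching the input dtype and
    # enabled_precisions to torch.float16 (and adapting the eager kernel and
    # asserts accordingly) is what triggers the "could not find any supported
    # formats consistent with input/output data types" failure.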

    print(torch.sigmoid(input))

    print(model(input))

    with torch_tensorrt.logging.debug():
        trt_inputs = [input]
        model_trt = torch_tensorrt.compile(
            model,
            inputs=trt_inputs,
            enabled_precisions={torch.float32},
            min_block_size=1,
        )
        print("Model compiled successfully!")
        print("Running inference with compiled model...")
        print("Compiled model output:")
        print(model_trt(input))
        print("Original model output:")
        print(model(input))
        with torch.no_grad():
            for i in range(10):
                res = model_trt(input)
                assert torch.allclose(
                    res, model(input), rtol=1e-2, atol=1e-2
                ), "Results do not match!"

    print("Inference successful!")