-
Notifications
You must be signed in to change notification settings - Fork 370
example: using nvrtc kernel for aot plugin #3881
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
bowang007
wants to merge
1
commit into
main
Choose a base branch
from
nvrtc_example
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,219 @@ | ||
| """ | ||
| Minimal reproducible example demonstrating TensorRT fp16 custom_op() issue. | ||
|
|
||
| This module shows the bug where torch_tensorrt.dynamo.conversion.plugins.custom_op() | ||
| fails to compile operations that use fp16 (half-precision) tensors. | ||
|
|
||
| The issue occurs because the JIT plugin generator doesn't properly declare format | ||
| support for fp16 data types in the generated TensorRT plugin. | ||
| """ | ||
|
|
||
| from typing import List, Tuple, Union | ||
|
|
||
| import torch | ||
|
|
||
| import torch_tensorrt | ||
|
|
||
| # CUDA kernel source (NVRTC) used by the torch custom op | ||
| # Note: TensorRT passes args as: inputs, extra_args, outputs | ||
|
|
||
| cu_code = """ | ||
| // Simple pointwise Sigmoid kernel: f(x) = 1 / (1 + exp(-x)) | ||
| extern "C" __global__ void pointwise_sigmoid_kernel_nvrtc(const float* __restrict__ input, | ||
| const int size, | ||
| float* __restrict__ output) { | ||
| const int idx = blockIdx.x * blockDim.x + threadIdx.x; | ||
|
|
||
| if (idx < size) { | ||
| const float x = input[idx]; | ||
| // use fast device intrinsic to avoid headers | ||
| output[idx] = 1.0f / (1.0f + __expf(-x)); | ||
| } | ||
| } | ||
| """ | ||
|
|
||
| # Prepare NVRTC program, kernel, and stream once (simple eager path) | ||
| from cuda.core.experimental import Device as _CudaDevice | ||
| from cuda.core.experimental import LaunchConfig as _LaunchConfig | ||
| from cuda.core.experimental import Program as _CudaProgram | ||
| from cuda.core.experimental import ProgramOptions as _CudaProgramOptions | ||
| from cuda.core.experimental import launch as _cuda_launch | ||
|
|
||
| _cuda_device = _CudaDevice() | ||
| _cuda_device.set_current() | ||
| _cuda_stream = _cuda_device.create_stream() | ||
| _program_options = _CudaProgramOptions( | ||
| std="c++17", | ||
| arch=f"sm_{_cuda_device.arch}", | ||
| include_path=["/usr/local/cuda/include"], | ||
| ) | ||
| _program = _CudaProgram(cu_code, code_type="c++", options=_program_options) | ||
| _module = _program.compile("ptx", name_expressions=("pointwise_sigmoid_kernel_nvrtc",)) | ||
| _kernel = _module.get_kernel("pointwise_sigmoid_kernel_nvrtc") | ||
|
|
||
| # Eager torch custom_op implemented using the CUDA kernel above (no Triton) | ||
|
|
||
|
|
||
| # ============================================================================ | ||
| # Custom Op Registration | ||
| # ============================================================================ | ||
|
|
||
|
|
||
| @torch.library.custom_op("pointwise_sigmoid_ops::pointwise_sigmoid", mutates_args=()) # type: ignore[misc] | ||
| def pointwise_sigmoid(X: torch.Tensor) -> torch.Tensor: | ||
| assert X.is_cuda, "Tensor must be on CUDA device." | ||
| assert X.dtype == torch.float32, "For this test, expected float32 input" | ||
|
|
||
| Y = torch.empty_like(X) | ||
| N = int(X.numel()) | ||
|
|
||
| block = 256 | ||
|
|
||
| grid_x = max(1, (N + block - 1) // block) | ||
| config = _LaunchConfig(grid=(grid_x), block=(block)) | ||
|
|
||
| # Use PyTorch's current stream by wrapping it for cuda.core | ||
| class _PyTorchStreamWrapper: | ||
| def __init__(self, pt_stream): | ||
| self.pt_stream = pt_stream | ||
|
|
||
| def __cuda_stream__(self): | ||
| stream_id = self.pt_stream.cuda_stream | ||
| return (0, stream_id) | ||
|
|
||
| pt_stream = torch.cuda.current_stream() | ||
| s = _cuda_device.create_stream(_PyTorchStreamWrapper(pt_stream)) | ||
|
|
||
| # Launch kernel with raw pointers as in cuda.core example | ||
| # Note: argument order is input, size, (matching TensorRT's convention) | ||
|
|
||
| _cuda_launch( | ||
| s, | ||
| config, | ||
| _kernel, | ||
| X.data_ptr(), | ||
| N, | ||
| Y.data_ptr(), | ||
| ) | ||
|
|
||
| return Y | ||
|
|
||
|
|
||
| @torch.library.register_fake("pointwise_sigmoid_ops::pointwise_sigmoid") | ||
| def _(input: torch.Tensor) -> torch.Tensor: | ||
| """Fake implementation for TorchDynamo tracing of base operation.""" | ||
| return torch.empty_like(input) | ||
|
|
||
|
|
||
| # ============================================================================ | ||
| # TensorRT Wrapper with custom_op() - THIS FAILS WITH FP16 | ||
| # ============================================================================ | ||
|
|
||
| import tensorrt.plugin as trtp | ||
| from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions | ||
|
|
||
|
|
||
| @trtp.register("pointwise_sigmoid_ops::pointwise_sigmoid") | ||
| def sigmoid_plugin_desc(input: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]: | ||
| return (input.like(),) | ||
|
|
||
|
|
||
| @trtp.autotune("pointwise_sigmoid_ops::pointwise_sigmoid") | ||
| def sigmoid_autotune( | ||
| input: trtp.TensorDesc, | ||
| outputs: Tuple[trtp.TensorDesc], | ||
| ) -> List[trtp.AutoTuneCombination]: | ||
| # Match float32 path; add FP16 if you want both | ||
| return [trtp.AutoTuneCombination("FP32, FP32", "LINEAR")] | ||
|
|
||
|
|
||
| @trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid") | ||
| def sigmoid_aot_nvrtc_impl( | ||
| input: trtp.TensorDesc, | ||
| outputs: Tuple[trtp.TensorDesc], | ||
| tactic: int, | ||
| ) -> Tuple[ | ||
| Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs | ||
| ]: | ||
|
|
||
| compiled_kernel = _module.code.decode("utf-8") | ||
| print(type(compiled_kernel)) | ||
| print(compiled_kernel) | ||
|
|
||
| # import pdb; pdb.set_trace() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove |
||
|
|
||
| N = input.shape_expr.numel() | ||
| launch_params = trtp.KernelLaunchParams() | ||
| block = 256 | ||
| launch_params.grid_x = trtp.cdiv(N, block) | ||
| launch_params.block_x = block | ||
| launch_params.shared_mem = 0 | ||
|
|
||
| extra_args = trtp.SymIntExprs(1) | ||
| extra_args[0] = trtp.SymInt32(N) | ||
|
|
||
| return ( | ||
| "pointwise_sigmoid_kernel_nvrtc", | ||
| compiled_kernel, | ||
| launch_params, | ||
| extra_args, | ||
| ) | ||
|
|
||
|
|
||
| torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter( | ||
| "pointwise_sigmoid_ops::pointwise_sigmoid", | ||
| supports_dynamic_shapes=True, | ||
| requires_output_allocator=False, | ||
| ) | ||
|
|
||
|
|
||
| # ============================================================================ | ||
| # Test Model | ||
| # ============================================================================ | ||
|
|
||
|
|
||
| class PointwiseSigmoidModel_WithTRTWrapper(torch.nn.Module): | ||
| """ | ||
| Test model that uses the TRT wrapper with custom_op() registration. | ||
|
|
||
| When compiled with torch_tensorrt.compile() using fp16 inputs, this will | ||
| fail with: "could not find any supported formats consistent with input/output | ||
| data types" | ||
| """ | ||
|
|
||
| def forward(self, input: torch.Tensor) -> torch.Tensor: | ||
|
|
||
| z = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(input) | ||
| return z | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| model = PointwiseSigmoidModel_WithTRTWrapper().to("cuda").eval() | ||
| input = torch.randn(1, 1024, device="cuda", dtype=torch.float32) | ||
|
|
||
| print(torch.sigmoid(input)) | ||
|
|
||
| print(model(input)) | ||
|
|
||
| with torch_tensorrt.logging.debug(): | ||
| trt_inputs = [input] | ||
| model_trt = torch_tensorrt.compile( | ||
| model, | ||
| inputs=trt_inputs, | ||
| enabled_precisions={torch.float32}, | ||
| min_block_size=1, | ||
| ) | ||
| print("Model compiled successfully!") | ||
| print("Running inference with compiled model...") | ||
| print("Compiled model output:") | ||
| print(model_trt(input)) | ||
| print("Original model output:") | ||
| print(model(input)) | ||
| with torch.no_grad(): | ||
| for i in range(10): | ||
| res = model_trt(input) | ||
| assert torch.allclose( | ||
| res, model(input), rtol=1e-2, atol=1e-2 | ||
| ), "Results do not match!" | ||
|
|
||
| # print("Inference successful!") | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can't we use the autogenerated registration for this?