diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td index dda8f31e688fe..a128acd7d47dd 100644 --- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td +++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td @@ -526,22 +526,6 @@ def NVGPU_MBarrierInitOp : NVGPU_Op<"mbarrier.init", []> { let assemblyFormat = "$barriers `[` $mbarId `]` `,` $count (`,` `predicate` `=` $predicate^)? attr-dict `:` type($barriers)"; } -def NVGPU_MBarrierTestWaitOp : NVGPU_Op<"mbarrier.test.wait", []> { - let summary = "Checks if the `nvgpu.mbarrier` has completed its current phase."; - let description = [{ - Checks whether the mbarrier object has completed the phase. It is is a - non-blocking instruction which tests for the completion of the phase. - - Example: - ```mlir - %isComplete = nvgpu.mbarrier.test.wait %barrier, %token : !nvgpu.mbarrier.barrier>, !nvgpu.mbarrier.token - ``` - }]; - let arguments = (ins NVGPU_MBarrierGroup:$barriers, NVGPU_MBarrierToken:$token, Index:$mbarId); - let results = (outs I1:$waitComplete); - let assemblyFormat = "$barriers `[` $mbarId `]` `,` $token attr-dict `:` type($barriers) `,` type($token)"; -} - def NVGPU_MBarrierArriveOp : NVGPU_Op<"mbarrier.arrive", []> { let summary = "Performs arrive operation on the `nvgpu.mbarrier.arrive`."; let description = [{ @@ -601,6 +585,22 @@ def NVGPU_MBarrierArriveExpectTxOp : NVGPU_Op<"mbarrier.arrive.expect_tx", []> { let assemblyFormat = "$barriers `[` $mbarId `]` `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type($barriers)"; } +def NVGPU_MBarrierTestOp : NVGPU_Op<"mbarrier.test", []> { + let summary = "Checks if the `nvgpu.mbarrier` has completed its current phase."; + let description = [{ + Checks whether the mbarrier object has completed the phase. It is is a + non-blocking instruction which tests for the completion of the phase. + + Example: + ```mlir + %isComplete = nvgpu.mbarrier.test %barrier, %token : !nvgpu.mbarrier.barrier>, !nvgpu.mbarrier.token + ``` + }]; + let arguments = (ins NVGPU_MBarrierGroup:$barriers, NVGPU_MBarrierToken:$token, Index:$mbarId); + let results = (outs I1:$waitComplete); + let assemblyFormat = "$barriers `[` $mbarId `]` `,` $token attr-dict `:` type($barriers) `,` type($token)"; +} + def NVGPU_MBarrierTryWaitParityOp : NVGPU_Op<"mbarrier.try_wait.parity", []> { let summary = "Waits for the `nvgpu.mbarrier` to complete its current phase."; let description = [{ diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 11d29754aa760..f610a0ecfdb7f 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -896,12 +896,12 @@ struct NVGPUMBarrierArriveNoCompleteLowering } }; -/// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait` -struct NVGPUMBarrierTestWaitLowering - : public MBarrierBasePattern { - using MBarrierBasePattern::MBarrierBasePattern; +/// Lowers `nvgpu.mbarrier.test` to `nvvm.mbarrier.test.wait` +struct NVGPUMBarrierTestLowering + : public MBarrierBasePattern { + using MBarrierBasePattern::MBarrierBasePattern; LogicalResult - matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor, + matchAndRewrite(nvgpu::MBarrierTestOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { ImplicitLocOpBuilder b(op->getLoc(), rewriter); Value barrier = @@ -1675,7 +1675,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter, NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete - NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test_wait_parity + NVGPUMBarrierTestLowering, // nvgpu.mbarrier.test NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load NVGPUTmaAsyncStoreOpLowering, // nvgpu.tma.async.store diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index 86a552c03a473..2454532d3e11e 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -497,7 +497,7 @@ func.func @mbarrier() { // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 // CHECK: nvvm.mbarrier.test.wait.shared %[[barPtr3]], %[[token]] - %isDone = nvgpu.mbarrier.test.wait %barrier[%c0], %token : !barrierType, !tokenType + %isDone = nvgpu.mbarrier.test %barrier[%c0], %token : !barrierType, !tokenType func.return } @@ -527,7 +527,7 @@ func.func @mbarrier_nocomplete() { // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 // CHECK: nvvm.mbarrier.test.wait.shared %[[barPtr3]], %[[token]] - %isDone = nvgpu.mbarrier.test.wait %barrier[%c0], %token : !barrierType, !tokenType + %isDone = nvgpu.mbarrier.test %barrier[%c0], %token : !barrierType, !tokenType func.return } @@ -552,7 +552,7 @@ func.func @mbarrier_wait(%barriers : !nvgpu.mbarrier.group, i64) -> !llvm.ptr<3>, i64 // CHECK: nvvm.mbarrier.test.wait.shared {{.*}}, %[[CARG1]] %mbarId = arith.remui %i, %numBarriers : index - %isDone = nvgpu.mbarrier.test.wait %barriers[%mbarId], %token : !nvgpu.mbarrier.group, num_barriers = 5>, !tokenType + %isDone = nvgpu.mbarrier.test %barriers[%mbarId], %token : !nvgpu.mbarrier.group, num_barriers = 5>, !tokenType } return }