
[mlir][nvgpu] Make mbarrier.try_wait fallable #96508


Open · chr1sj0nes wants to merge 1 commit into main

Conversation

chr1sj0nes (Contributor)

The op will make a single attempt to wait, then fail, rather than looping. This matches the behaviour described in the docs, and means that the `ticks` parameter is actually meaningful.
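
Purely as an illustration (not part of this PR), a minimal sketch of how a caller could recover the previous blocking behaviour by retrying the now-fallible op in an `scf.while` loop; the SSA names, constant values, and barrier type are assumptions drawn from the examples in the diff:

```mlir
// Sketch: retry the single-attempt wait until the phase completes.
// %barrier is assumed to be a previously created !nvgpu.mbarrier.group value.
%c0 = arith.constant 0 : index
%phase = arith.constant 0 : i1
%ticks = arith.constant 10000000 : index
%true = arith.constant true
scf.while : () -> () {
  // One bounded attempt; %isDone is true iff the wait finished within %ticks.
  %isDone = nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase, %ticks
      : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
  // Keep looping while the attempt timed out.
  %notDone = arith.xori %isDone, %true : i1
  scf.condition(%notDone)
} do {
  scf.yield
}
```
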
llvmbot (Member) commented Jun 24, 2024

@llvm/pr-subscribers-mlir-nvgpu
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Chris Jones (chr1sj0nes)

Changes

The op will make a single attempt to wait, then fail, rather than looping. This matches the behaviour described in the docs, and means that the `ticks` parameter is actually meaningful.


Full diff: https://github.com/llvm/llvm-project/pull/96508.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+10-26)
  • (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td (+2-1)
  • (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+4-3)
  • (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+1-1)
  • (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+20-21)
  • (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+11-25)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 4d48b3de7a57e..eac71a59841e1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -323,40 +323,24 @@ def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.ex
   }];
 }
 
-def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,  
-  Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> {  
-  let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
+def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,
+  Results<(outs LLVM_Type:$res)>,
+  Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> {
+  let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands) `->` type($res)";
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() {
-      return std::string(
-        "{\n\t"
-        ".reg .pred       P1; \n\t"
-        "LAB_WAIT: \n\t"
-        "mbarrier.try_wait.parity.b64 P1, [%0], %1, %2; \n\t"
-        "@P1 bra.uni DONE; \n\t"
-        "bra.uni     LAB_WAIT; \n\t"
-        "DONE: \n\t"
-        "}"
-      ); 
+      return std::string("mbarrier.try_wait.parity.b64 %0, [%1], %2, %3;");
     }
   }];
 }
 
-def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,  
-  Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> {  
-  let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
+def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,
+  Results<(outs LLVM_Type:$res)>,
+  Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> {
+  let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands) `->` type($res)";
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() {
-      return std::string(
-        "{\n\t"
-        ".reg .pred       P1; \n\t"
-        "LAB_WAIT: \n\t"
-        "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t"
-        "@P1 bra.uni DONE; \n\t"
-        "bra.uni     LAB_WAIT; \n\t"
-        "DONE: \n\t"
-        "}"
-      ); 
+      return std::string("mbarrier.try_wait.parity.shared.b64 %0, [%1], %2, %3;");
     }
   }];
 }
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index dda8f31e688fe..d946a1cf49ee2 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -614,10 +614,11 @@ def NVGPU_MBarrierTryWaitParityOp : NVGPU_Op<"mbarrier.try_wait.parity", []> {
 
     Example:
     ```mlir
-      nvgpu.mbarrier.try_wait.parity %barrier, %phaseParity, %ticks : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
+      %isComplete = nvgpu.mbarrier.try_wait.parity %barrier, %phaseParity, %ticks : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
     ```
   }];
   let arguments = (ins NVGPU_MBarrierGroup:$barriers, I1:$phaseParity, Index:$ticks, Index:$mbarId);
+  let results = (outs I1:$waitComplete);
   let assemblyFormat = "$barriers `[` $mbarId `]` `,` $phaseParity `,` $ticks attr-dict `:` type($barriers)";  
 }
 
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 11d29754aa760..bcc1d9eb7766c 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -958,15 +958,16 @@ struct NVGPUMBarrierTryWaitParityLowering
     Value ticks = truncToI32(b, adaptor.getTicks());
     Value phase =
         b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());
+    Type retType = rewriter.getI1Type();
 
     if (isMbarrierShared(op.getBarriers().getType())) {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
-          op, barrier, phase, ticks);
+          op, retType, barrier, phase, ticks);
       return success();
     }
 
-    rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
-                                                               phase, ticks);
+    rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(
+        op, retType, barrier, phase, ticks);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 4e256aea0be37..356c1621f81e6 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -1018,7 +1018,7 @@ void HopperBuilder::buildTryWaitParity(
   Value ticksBeforeRetry =
       rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
+  rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, i1, barrier, parity,
                                                   ticksBeforeRetry, zero);
 }
 
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 86a552c03a473..5d8e5d1e5a2db 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -558,7 +558,7 @@ func.func @mbarrier_wait(%barriers : !nvgpu.mbarrier.group<memorySpace = #gpu.ad
 }
 
 // CHECK-LABEL: func @mbarrier_txcount
-func.func @mbarrier_txcount() {
+func.func @mbarrier_txcount() -> i1 {
     %num_threads = arith.constant 128 : index
     // CHECK: %[[c0:.+]] = arith.constant 0 : index
     // CHECK: %[[mid:.+]] = builtin.unrealized_conversion_cast %[[c0]] : index to i64
@@ -568,50 +568,49 @@ func.func @mbarrier_txcount() {
     %barrier = nvgpu.mbarrier.create -> !barrierType
 
     // CHECK: %[[barStr:.+]] =  builtin.unrealized_conversion_cast %[[barMemref]] : memref<1xi64, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: %[[barPtr:.+]] = llvm.getelementptr %[[base]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
     // CHECK: nvvm.mbarrier.init.shared %[[barPtr]]
     nvgpu.mbarrier.init %barrier[%c0], %num_threads : !barrierType
-    
+
     %tidxreg = nvvm.read.ptx.sreg.tid.x : i32
     %tidx = arith.index_cast %tidxreg : i32 to index
-    %cnd = arith.cmpi eq, %tidx, %c0 : index  
+    %cnd = arith.cmpi eq, %tidx, %c0 : index
 
     scf.if %cnd {
       %txcount = arith.constant 256 : index
-      // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
+      // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
       // CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
       // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
       nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
-      scf.yield 
+      scf.yield
     } else {
       %txcount = arith.constant 0 : index
-      // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
+      // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
       // CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
       // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
       nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
-      scf.yield 
+      scf.yield
     }
-      
 
     %phase_c0 = arith.constant 0 : i1
     %ticks = arith.constant 10000000 : index
-    // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
-    // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
-    nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
+    // CHECK: %[[isDone:.+]] = nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
+    %isDone = nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
 
-    func.return 
+    func.return %isDone : i1
 }
 
 // CHECK-LABEL: func @mbarrier_txcount_pred
-func.func @mbarrier_txcount_pred() {
+func.func @mbarrier_txcount_pred() -> i1 {
     %mine = arith.constant 1 : index
     // CHECK: %[[c0:.+]] = arith.constant 0 : index
     // CHECK: %[[mid:.+]] = builtin.unrealized_conversion_cast %[[c0]] : index to i64
     // CHECK: %[[S2:.+]] = gpu.thread_id  x
     // CHECK: %[[P:.+]] = arith.cmpi eq, %[[S2]], %[[c0]] : index
-    %c0 = arith.constant 0 : index  
+    %c0 = arith.constant 0 : index
     %tidx = gpu.thread_id x
     %pred = arith.cmpi eq, %tidx, %c0 : index
 
@@ -619,25 +618,25 @@ func.func @mbarrier_txcount_pred() {
     %barrier = nvgpu.mbarrier.create -> !barrierType
 
     // CHECK: %[[barStr:.+]] =  builtin.unrealized_conversion_cast %[[barMemref]] : memref<1xi64, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
-    // CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: %[[barPtr:.+]] = llvm.getelementptr %[[base]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
     // CHECK: nvvm.mbarrier.init.shared %[[barPtr]], {{.*}}, predicate = %[[P]]
     nvgpu.mbarrier.init %barrier[%c0], %mine, predicate = %pred : !barrierType
     
     %txcount = arith.constant 256 : index
-    // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
     // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]], {{.*}}, predicate = %[[P]]
     nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount, predicate = %pred : !barrierType
 
     %phase_c0 = arith.constant 0 : i1
     %ticks = arith.constant 10000000 : index
-    // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
+    // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
     // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
-    // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
-    nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
+    // CHECK: %[[isDone:.+]] = nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
+    %isDone = nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
 
-    func.return 
+    func.return %isDone : i1
 }
 
 // CHECK-LABEL: func @async_tma_load
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 21947c242461e..a2ec210985863 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -32,35 +32,21 @@ llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount
 }
 
 // CHECK-LABEL: @init_mbarrier_try_wait_shared
-llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32, %phase : i32) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "{
-  // CHECK-SAME: .reg .pred       P1;
-  // CHECK-SAME: LAB_WAIT: 
-  // CHECK-SAME: mbarrier.try_wait.parity.shared.b64 P1, [$0], $1, $2;
-  // CHECK-SAME: @P1 bra.uni DONE;
-  // CHECK-SAME: bra.uni     LAB_WAIT;
-  // CHECK-SAME: DONE:
-  // CHECK-SAME: }",
-  // CHECK-SAME: "r,r,r"
-   nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
-  llvm.return
+llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32, %phase : i32) -> i1 {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+  // CHECK-SAME: "mbarrier.try_wait.parity.shared.b64 $0, [$1], $2, $3;",
+  // CHECK-SAME: "=b,r,r,r"
+  %isDone = nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32 -> i1
+  llvm.return %isDone : i1
 }
 
 // CHECK-LABEL: @init_mbarrier_try_wait
-llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %ticks : i32, %phase : i32){
+llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %ticks : i32, %phase : i32) -> i1 {
   // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
-  // CHECK-SAME: "{
-  // CHECK-SAME: .reg .pred       P1;
-  // CHECK-SAME: LAB_WAIT: 
-  // CHECK-SAME: mbarrier.try_wait.parity.b64 P1, [$0], $1, $2;
-  // CHECK-SAME: @P1 bra.uni DONE;
-  // CHECK-SAME: bra.uni     LAB_WAIT;
-  // CHECK-SAME: DONE:
-  // CHECK-SAME: }",
-  // CHECK-SAME: "l,r,r"
-  nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr, i32, i32
-  llvm.return
+  // CHECK-SAME: "mbarrier.try_wait.parity.b64 $0, [$1], $2, $3;",
+  // CHECK-SAME: "=b,l,r,r"
+  %isDone = nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr, i32, i32 -> i1
+  llvm.return %isDone : i1
 }
 
 // CHECK-LABEL: @async_cp

grypp (Member) commented Jun 24, 2024

This is going to be a breaking change. Have you run the sm_90 integration tests locally? The CI passes here because we don't have sm_90 GPUs.

Comment on lines -332 to -340
"{\n\t"
".reg .pred P1; \n\t"
"LAB_WAIT: \n\t"
"mbarrier.try_wait.parity.b64 P1, [%0], %1, %2; \n\t"
"@P1 bra.uni DONE; \n\t"
"bra.uni LAB_WAIT; \n\t"
"DONE: \n\t"
"}"
);
grypp (Member)

This is quite a useful PTX blob. We could keep it in the op as the default PTX, and add an alternate path for the single-instruction form. What do you think?

chr1sj0nes (Contributor, Author)

I agree that having a variant that loops is useful, but my understanding is that the NVVM dialect should map directly onto the PTX primitives. I think the best solution would be to add an `nvgpu.mbarrier.wait` op which lowers to the PTX above.
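
Purely illustrative (no such op exists in this PR; the name and syntax are invented here), a blocking variant could mirror the fallible op but produce no result:

```mlir
// Hypothetical blocking wait (invented op name): would lower to the looping
// PTX blob this PR removes, spinning until the given phase completes.
nvgpu.mbarrier.wait.parity %barrier[%c0], %phase, %ticks
    : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
```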

grypp (Member)

I don't think we need a new op. You can add a unit attribute on `NVVM_MBarrierTryWaitParityOp` that opts out of the loop generation.
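
A sketch of that suggestion (the attribute name is invented and not part of the dialect): the single-attempt lowering could be selected through the op's attr-dict, with the looping PTX blob remaining the default:

```mlir
// Hypothetical unit attribute opting out of the loop generation, i.e.
// selecting the single, fallible attempt that this PR introduces.
%isDone = nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks {singleAttempt}
    : !llvm.ptr, i32, i32 -> i1
```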

chr1sj0nes (Contributor, Author)

No, I haven't run all the tests. I will do that.

@chr1sj0nes chr1sj0nes changed the title [mlir:nvgpu] Make mbarrier.try_wait fallable [mlir][nvgpu] Make mbarrier.try_wait fallable Jun 25, 2024