[mlir][nvgpu] Make mbarrier.try_wait fallable #96508
Conversation
The op will make a single attempt to wait, then fail, rather than looping. This matches the behaviour described in the docs, and means that the `ticks` parameter is actually meaningful.

@llvm/pr-subscribers-mlir-nvgpu @llvm/pr-subscribers-mlir

Author: Chris Jones (chr1sj0nes)

Full diff: https://github.com/llvm/llvm-project/pull/96508.diff (6 files affected)
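As a rough illustration of what the change enables (not part of the PR), here is a minimal sketch of a caller consuming the new `i1` result. The surrounding SSA values (`%barrier`, `%c0`, `%phaseParity`, `%ticks`) and the exact barrier type are assumptions taken from the existing tests:

```mlir
// Sketch only: react to a timed-out wait instead of spinning inside the op.
// All SSA values used here are assumed to be defined in the surrounding code.
%isComplete = nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phaseParity, %ticks
    : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
scf.if %isComplete {
  // The barrier phase completed within the %ticks time hint.
} else {
  // Timed out after a single attempt: retry, do other useful work, or abort.
}
```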
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 4d48b3de7a57e..eac71a59841e1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -323,40 +323,24 @@ def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.ex
}];
}
-def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,
- Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> {
- let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
+def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,
+ Results<(outs LLVM_Type:$res)>,
+ Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> {
+ let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands) `->` type($res)";
let extraClassDefinition = [{
std::string $cppClass::getPtx() {
- return std::string(
- "{\n\t"
- ".reg .pred P1; \n\t"
- "LAB_WAIT: \n\t"
- "mbarrier.try_wait.parity.b64 P1, [%0], %1, %2; \n\t"
- "@P1 bra.uni DONE; \n\t"
- "bra.uni LAB_WAIT; \n\t"
- "DONE: \n\t"
- "}"
- );
+ return std::string("mbarrier.try_wait.parity.b64 %0, [%1], %2, %3;");
}
}];
}
-def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,
- Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> {
- let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
+def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,
+ Results<(outs LLVM_Type:$res)>,
+ Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> {
+ let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands) `->` type($res)";
let extraClassDefinition = [{
std::string $cppClass::getPtx() {
- return std::string(
- "{\n\t"
- ".reg .pred P1; \n\t"
- "LAB_WAIT: \n\t"
- "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t"
- "@P1 bra.uni DONE; \n\t"
- "bra.uni LAB_WAIT; \n\t"
- "DONE: \n\t"
- "}"
- );
+ return std::string("mbarrier.try_wait.parity.shared.b64 %0, [%1], %2, %3;");
}
}];
}
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index dda8f31e688fe..d946a1cf49ee2 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -614,10 +614,11 @@ def NVGPU_MBarrierTryWaitParityOp : NVGPU_Op<"mbarrier.try_wait.parity", []> {
Example:
```mlir
- nvgpu.mbarrier.try_wait.parity %barrier, %phaseParity, %ticks : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
+ %isComplete = nvgpu.mbarrier.try_wait.parity %barrier, %phaseParity, %ticks : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
```
}];
let arguments = (ins NVGPU_MBarrierGroup:$barriers, I1:$phaseParity, Index:$ticks, Index:$mbarId);
+  let results = (outs I1:$waitComplete);
let assemblyFormat = "$barriers `[` $mbarId `]` `,` $phaseParity `,` $ticks attr-dict `:` type($barriers)";
}
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 11d29754aa760..bcc1d9eb7766c 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -958,15 +958,16 @@ struct NVGPUMBarrierTryWaitParityLowering
Value ticks = truncToI32(b, adaptor.getTicks());
Value phase =
b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());
+ Type retType = rewriter.getI1Type();
if (isMbarrierShared(op.getBarriers().getType())) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
- op, barrier, phase, ticks);
+ op, retType, barrier, phase, ticks);
return success();
}
- rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
- phase, ticks);
+ rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(
+ op, retType, barrier, phase, ticks);
return success();
}
};
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 4e256aea0be37..356c1621f81e6 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -1018,7 +1018,7 @@ void HopperBuilder::buildTryWaitParity(
Value ticksBeforeRetry =
rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
+ rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, i1, barrier, parity,
ticksBeforeRetry, zero);
}
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 86a552c03a473..5d8e5d1e5a2db 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -558,7 +558,7 @@ func.func @mbarrier_wait(%barriers : !nvgpu.mbarrier.group<memorySpace = #gpu.ad
}
// CHECK-LABEL: func @mbarrier_txcount
-func.func @mbarrier_txcount() {
+func.func @mbarrier_txcount() -> i1 {
%num_threads = arith.constant 128 : index
// CHECK: %[[c0:.+]] = arith.constant 0 : index
// CHECK: %[[mid:.+]] = builtin.unrealized_conversion_cast %[[c0]] : index to i64
@@ -568,50 +568,49 @@ func.func @mbarrier_txcount() {
%barrier = nvgpu.mbarrier.create -> !barrierType
// CHECK: %[[barStr:.+]] = builtin.unrealized_conversion_cast %[[barMemref]] : memref<1xi64, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr:.+]] = llvm.getelementptr %[[base]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
// CHECK: nvvm.mbarrier.init.shared %[[barPtr]]
nvgpu.mbarrier.init %barrier[%c0], %num_threads : !barrierType
-
+
%tidxreg = nvvm.read.ptx.sreg.tid.x : i32
%tidx = arith.index_cast %tidxreg : i32 to index
- %cnd = arith.cmpi eq, %tidx, %c0 : index
+ %cnd = arith.cmpi eq, %tidx, %c0 : index
scf.if %cnd {
%txcount = arith.constant 256 : index
- // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
// CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
- scf.yield
+ scf.yield
} else {
%txcount = arith.constant 0 : index
- // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
// CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
- scf.yield
+ scf.yield
}
-
%phase_c0 = arith.constant 0 : i1
%ticks = arith.constant 10000000 : index
- // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
- nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
+ // CHECK: %[[isDone:.+]] = nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
+ %isDone = nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
- func.return
+ func.return %isDone : i1
}
// CHECK-LABEL: func @mbarrier_txcount_pred
-func.func @mbarrier_txcount_pred() {
+func.func @mbarrier_txcount_pred() -> i1 {
%mine = arith.constant 1 : index
// CHECK: %[[c0:.+]] = arith.constant 0 : index
// CHECK: %[[mid:.+]] = builtin.unrealized_conversion_cast %[[c0]] : index to i64
// CHECK: %[[S2:.+]] = gpu.thread_id x
// CHECK: %[[P:.+]] = arith.cmpi eq, %[[S2]], %[[c0]] : index
- %c0 = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
%tidx = gpu.thread_id x
%pred = arith.cmpi eq, %tidx, %c0 : index
@@ -619,25 +618,25 @@ func.func @mbarrier_txcount_pred() {
%barrier = nvgpu.mbarrier.create -> !barrierType
// CHECK: %[[barStr:.+]] = builtin.unrealized_conversion_cast %[[barMemref]] : memref<1xi64, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr:.+]] = llvm.getelementptr %[[base]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
// CHECK: nvvm.mbarrier.init.shared %[[barPtr]], {{.*}}, predicate = %[[P]]
nvgpu.mbarrier.init %barrier[%c0], %mine, predicate = %pred : !barrierType
%txcount = arith.constant 256 : index
- // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
// CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]], {{.*}}, predicate = %[[P]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount, predicate = %pred : !barrierType
%phase_c0 = arith.constant 0 : i1
%ticks = arith.constant 10000000 : index
- // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
- nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
+ // CHECK: %[[isDone:.+]] = nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
+ %isDone = nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
- func.return
+ func.return %isDone : i1
}
// CHECK-LABEL: func @async_tma_load
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 21947c242461e..a2ec210985863 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -32,35 +32,21 @@ llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount
}
// CHECK-LABEL: @init_mbarrier_try_wait_shared
-llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32, %phase : i32) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
- // CHECK-SAME: "{
- // CHECK-SAME: .reg .pred P1;
- // CHECK-SAME: LAB_WAIT:
- // CHECK-SAME: mbarrier.try_wait.parity.shared.b64 P1, [$0], $1, $2;
- // CHECK-SAME: @P1 bra.uni DONE;
- // CHECK-SAME: bra.uni LAB_WAIT;
- // CHECK-SAME: DONE:
- // CHECK-SAME: }",
- // CHECK-SAME: "r,r,r"
- nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
- llvm.return
+llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32, %phase : i32) -> i1 {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att
+ // CHECK-SAME: "mbarrier.try_wait.parity.shared.b64 $0, [$1], $2, $3;",
+ // CHECK-SAME: "=b,r,r,r"
+ %isDone = nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32 -> i1
+ llvm.return %isDone : i1
}
// CHECK-LABEL: @init_mbarrier_try_wait
-llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %ticks : i32, %phase : i32){
+llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %ticks : i32, %phase : i32) -> i1 {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
- // CHECK-SAME: "{
- // CHECK-SAME: .reg .pred P1;
- // CHECK-SAME: LAB_WAIT:
- // CHECK-SAME: mbarrier.try_wait.parity.b64 P1, [$0], $1, $2;
- // CHECK-SAME: @P1 bra.uni DONE;
- // CHECK-SAME: bra.uni LAB_WAIT;
- // CHECK-SAME: DONE:
- // CHECK-SAME: }",
- // CHECK-SAME: "l,r,r"
- nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr, i32, i32
- llvm.return
+ // CHECK-SAME: "mbarrier.try_wait.parity.b64 $0, [$1], $2, $3;",
+ // CHECK-SAME: "=b,l,r,r"
+ %isDone = nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr, i32, i32 -> i1
+ llvm.return %isDone : i1
}
// CHECK-LABEL: @async_cp
This is going to be a breaking change. Have you run the sm_90 integration tests locally? The CI passes here because we don't have sm_90 GPUs.
"{\n\t" | ||
".reg .pred P1; \n\t" | ||
"LAB_WAIT: \n\t" | ||
"mbarrier.try_wait.parity.b64 P1, [%0], %1, %2; \n\t" | ||
"@P1 bra.uni DONE; \n\t" | ||
"bra.uni LAB_WAIT; \n\t" | ||
"DONE: \n\t" | ||
"}" | ||
); |
This is quite a useful PTX blob. We could keep it in the op as the default PTX and add an alternate path for the single instruction. What do you think?
I agree that having a variant that loops is useful, but my understanding is that the `nvvm` dialect should map directly onto the PTX primitives. I think the best solution would be to add an `nvgpu.mbarrier.wait` op which lowers to the PTX above.
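For illustration only, a sketch of how the looping behaviour could be recovered above the NVVM level by wrapping the now-fallible op in `scf.while`; the enclosing values, and the idea that a future `nvgpu.mbarrier.wait` would lower to something like this, are assumptions rather than part of this PR:

```mlir
// Hypothetical retry loop mirroring the removed inline-asm loop: keep issuing
// single-attempt waits until one reports completion. %barrier, %c0,
// %phaseParity and %ticks are assumed to be defined in the surrounding code.
%false = arith.constant false
%true = arith.constant true
scf.while (%done = %false) : (i1) -> () {
  %notDone = arith.xori %done, %true : i1  // continue while the wait has not completed
  scf.condition(%notDone)
} do {
  %isDone = nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phaseParity, %ticks
      : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
  scf.yield %isDone : i1
}
```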
I don't think we need a new op. You can add a unit attribute on `NVVM_MBarrierTryWaitParityOp` that opts out of the loop generation.
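For concreteness, one possible shape of that opt-out at the NVVM level; the attribute name `singleAttempt` is purely hypothetical and not something this PR or the dialect defines:

```mlir
// Hypothetical: a unit attribute selecting the single-attempt lowering from
// this PR, with the looping inline-asm blob remaining the default behaviour.
%isDone = nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks {singleAttempt}
    : !llvm.ptr<3>, i32, i32 -> i1
```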
No, I haven't run all the tests. I will do that.