Skip to content

[mlir][nvgpu] Make mbarrier.try_wait fallible #96508

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 10 additions & 26 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -323,40 +323,24 @@ def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.ex
}];
}

def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,
Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> {
let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,
Results<(outs LLVM_Type:$res)>,
Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> {
let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands) `->` type($res)";
let extraClassDefinition = [{
std::string $cppClass::getPtx() {
return std::string(
"{\n\t"
".reg .pred P1; \n\t"
"LAB_WAIT: \n\t"
"mbarrier.try_wait.parity.b64 P1, [%0], %1, %2; \n\t"
"@P1 bra.uni DONE; \n\t"
"bra.uni LAB_WAIT; \n\t"
"DONE: \n\t"
"}"
);
Comment on lines -332 to -340
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a quite useful PTX blob. We could keep this in the op as the default PTX, and add an alternate path for the single-instruction form. What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that having a variant that loops is useful, but my understanding is that the nvvm dialect should map directly onto the PTX primitives. I think the best solution would be to add an nvgpu.mbarrier.wait op which lowers to the PTX above.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need a new op. You can add a unit attribute on NVVM_MBarrierTryWaitParityOp that opts out of the loop generation.

return std::string("mbarrier.try_wait.parity.b64 %0, [%1], %2, %3;");
}
}];
}

def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,
Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> {
let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,
Results<(outs LLVM_Type:$res)>,
Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> {
let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands) `->` type($res)";
let extraClassDefinition = [{
std::string $cppClass::getPtx() {
return std::string(
"{\n\t"
".reg .pred P1; \n\t"
"LAB_WAIT: \n\t"
"mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t"
"@P1 bra.uni DONE; \n\t"
"bra.uni LAB_WAIT; \n\t"
"DONE: \n\t"
"}"
);
return std::string("mbarrier.try_wait.parity.shared.b64 %0, [%1], %2, %3;");
}
}];
}
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -614,10 +614,11 @@ def NVGPU_MBarrierTryWaitParityOp : NVGPU_Op<"mbarrier.try_wait.parity", []> {

Example:
```mlir
nvgpu.mbarrier.try_wait.parity %barrier, %phaseParity, %ticks : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
%isComplete = nvgpu.mbarrier.try_wait.parity %barrier, %phaseParity, %ticks : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
```
}];
let arguments = (ins NVGPU_MBarrierGroup:$barriers, I1:$phaseParity, Index:$ticks, Index:$mbarId);
let results = (outs I1:$waitComplete);
let assemblyFormat = "$barriers `[` $mbarId `]` `,` $phaseParity `,` $ticks attr-dict `:` type($barriers)";
}

Expand Down
7 changes: 4 additions & 3 deletions mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -958,15 +958,16 @@ struct NVGPUMBarrierTryWaitParityLowering
Value ticks = truncToI32(b, adaptor.getTicks());
Value phase =
b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());
Type retType = rewriter.getI1Type();

if (isMbarrierShared(op.getBarriers().getType())) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
op, barrier, phase, ticks);
op, retType, barrier, phase, ticks);
return success();
}

rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
phase, ticks);
rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(
op, retType, barrier, phase, ticks);
return success();
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,7 @@ void HopperBuilder::buildTryWaitParity(
Value ticksBeforeRetry =
rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, i1, barrier, parity,
ticksBeforeRetry, zero);
}

Expand Down
41 changes: 20 additions & 21 deletions mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ func.func @mbarrier_wait(%barriers : !nvgpu.mbarrier.group<memorySpace = #gpu.ad
}

// CHECK-LABEL: func @mbarrier_txcount
func.func @mbarrier_txcount() {
func.func @mbarrier_txcount() -> i1 {
%num_threads = arith.constant 128 : index
// CHECK: %[[c0:.+]] = arith.constant 0 : index
// CHECK: %[[mid:.+]] = builtin.unrealized_conversion_cast %[[c0]] : index to i64
Expand All @@ -568,76 +568,75 @@ func.func @mbarrier_txcount() {
%barrier = nvgpu.mbarrier.create -> !barrierType

// CHECK: %[[barStr:.+]] = builtin.unrealized_conversion_cast %[[barMemref]] : memref<1xi64, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr:.+]] = llvm.getelementptr %[[base]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
// CHECK: nvvm.mbarrier.init.shared %[[barPtr]]
nvgpu.mbarrier.init %barrier[%c0], %num_threads : !barrierType

%tidxreg = nvvm.read.ptx.sreg.tid.x : i32
%tidx = arith.index_cast %tidxreg : i32 to index
%cnd = arith.cmpi eq, %tidx, %c0 : index
%cnd = arith.cmpi eq, %tidx, %c0 : index

scf.if %cnd {
%txcount = arith.constant 256 : index
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
// CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
scf.yield
scf.yield
} else {
%txcount = arith.constant 0 : index
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
// CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
scf.yield
scf.yield
}


%phase_c0 = arith.constant 0 : i1
%ticks = arith.constant 10000000 : index
// CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
// CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
// CHECK: %[[isDone:.+]] = nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
%isDone = nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType

func.return
func.return %isDone : i1
}

// CHECK-LABEL: func @mbarrier_txcount_pred
func.func @mbarrier_txcount_pred() {
func.func @mbarrier_txcount_pred() -> i1 {
%mine = arith.constant 1 : index
// CHECK: %[[c0:.+]] = arith.constant 0 : index
// CHECK: %[[mid:.+]] = builtin.unrealized_conversion_cast %[[c0]] : index to i64
// CHECK: %[[S2:.+]] = gpu.thread_id x
// CHECK: %[[P:.+]] = arith.cmpi eq, %[[S2]], %[[c0]] : index
%c0 = arith.constant 0 : index
%c0 = arith.constant 0 : index
%tidx = gpu.thread_id x
%pred = arith.cmpi eq, %tidx, %c0 : index

// CHECK: %[[barMemref:.+]] = memref.get_global @__mbarrier{{.*}} : memref<1xi64, 3>
%barrier = nvgpu.mbarrier.create -> !barrierType

// CHECK: %[[barStr:.+]] = builtin.unrealized_conversion_cast %[[barMemref]] : memref<1xi64, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr:.+]] = llvm.getelementptr %[[base]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
// CHECK: nvvm.mbarrier.init.shared %[[barPtr]], {{.*}}, predicate = %[[P]]
nvgpu.mbarrier.init %barrier[%c0], %mine, predicate = %pred : !barrierType

%txcount = arith.constant 256 : index
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
// CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]], {{.*}}, predicate = %[[P]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount, predicate = %pred : !barrierType

%phase_c0 = arith.constant 0 : i1
%ticks = arith.constant 10000000 : index
// CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
// CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
// CHECK: %[[isDone:.+]] = nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
%isDone = nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType

func.return
func.return %isDone : i1
}

// CHECK-LABEL: func @async_tma_load
Expand Down
36 changes: 11 additions & 25 deletions mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -32,35 +32,21 @@ llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount
}

// CHECK-LABEL: @init_mbarrier_try_wait_shared
llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32, %phase : i32) {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: "{
// CHECK-SAME: .reg .pred P1;
// CHECK-SAME: LAB_WAIT:
// CHECK-SAME: mbarrier.try_wait.parity.shared.b64 P1, [$0], $1, $2;
// CHECK-SAME: @P1 bra.uni DONE;
// CHECK-SAME: bra.uni LAB_WAIT;
// CHECK-SAME: DONE:
// CHECK-SAME: }",
// CHECK-SAME: "r,r,r"
nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
llvm.return
llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32, %phase : i32) -> i1 {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: "mbarrier.try_wait.parity.shared.b64 $0, [$1], $2, $3;",
// CHECK-SAME: "=b,r,r,r"
%isDone = nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32 -> i1
llvm.return %isDone : i1
}

// CHECK-LABEL: @init_mbarrier_try_wait
llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %ticks : i32, %phase : i32){
llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %ticks : i32, %phase : i32) -> i1 {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: "{
// CHECK-SAME: .reg .pred P1;
// CHECK-SAME: LAB_WAIT:
// CHECK-SAME: mbarrier.try_wait.parity.b64 P1, [$0], $1, $2;
// CHECK-SAME: @P1 bra.uni DONE;
// CHECK-SAME: bra.uni LAB_WAIT;
// CHECK-SAME: DONE:
// CHECK-SAME: }",
// CHECK-SAME: "l,r,r"
nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr, i32, i32
llvm.return
// CHECK-SAME: "mbarrier.try_wait.parity.b64 $0, [$1], $2, $3;",
// CHECK-SAME: "=b,l,r,r"
%isDone = nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr, i32, i32 -> i1
llvm.return %isDone : i1
}

// CHECK-LABEL: @async_cp
Expand Down
Loading