From 316efd4ee6b86237267fc296892c7d03aa428333 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Mon, 24 Jun 2024 16:16:07 +0100 Subject: [PATCH] [mlir:nvgpu] Make `mbarrier.try_wait` non-blocking. This matches the behaviour described in the docs, and means that the `ticks` parameter is actually meaningful. --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 36 +++++----------- mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td | 3 +- .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 7 ++-- .../NVGPU/TransformOps/NVGPUTransformOps.cpp | 2 +- .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 41 +++++++++---------- .../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir | 36 +++++----------- 6 files changed, 48 insertions(+), 77 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 4d48b3de7a57e..eac71a59841e1 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -323,40 +323,24 @@ def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.ex }]; } -def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">, - Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> { - let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)"; +def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">, + Results<(outs LLVM_Type:$res)>, + Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> { + let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands) `->` type($res)"; let extraClassDefinition = [{ std::string $cppClass::getPtx() { - return std::string( - "{\n\t" - ".reg .pred P1; \n\t" - "LAB_WAIT: \n\t" - "mbarrier.try_wait.parity.b64 P1, [%0], %1, %2; \n\t" - "@P1 bra.uni DONE; \n\t" - "bra.uni LAB_WAIT; \n\t" - "DONE: \n\t" - "}" - ); + return std::string("mbarrier.try_wait.parity.b64 %0, [%1], %2, %3;"); } }]; } -def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">, - Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> { - let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)"; +def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">, + Results<(outs LLVM_Type:$res)>, + Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> { + let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands) `->` type($res)"; let extraClassDefinition = [{ std::string $cppClass::getPtx() { - return std::string( - "{\n\t" - ".reg .pred P1; \n\t" - "LAB_WAIT: \n\t" - "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t" - "@P1 bra.uni DONE; \n\t" - "bra.uni LAB_WAIT; \n\t" - "DONE: \n\t" - "}" - ); + return std::string("mbarrier.try_wait.parity.shared.b64 %0, [%1], %2, %3;"); } }]; } diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td index dda8f31e688fe..d946a1cf49ee2 100644 --- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td +++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td @@ -614,10 +614,11 @@ def NVGPU_MBarrierTryWaitParityOp : NVGPU_Op<"mbarrier.try_wait.parity", []> { Example: ```mlir - nvgpu.mbarrier.try_wait.parity %barrier, %phaseParity, %ticks : !nvgpu.mbarrier.barrier> + %isComplete = nvgpu.mbarrier.try_wait.parity %barrier, %phaseParity, %ticks : !nvgpu.mbarrier.barrier> ``` }]; let arguments = (ins NVGPU_MBarrierGroup:$barriers, I1:$phaseParity, Index:$ticks, Index:$mbarId); + let results = (outs I1:$waitComplete); let assemblyFormat = "$barriers `[` $mbarId `]` `,` $phaseParity `,` $ticks attr-dict `:` type($barriers)"; } diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 11d29754aa760..bcc1d9eb7766c 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -958,15 +958,16 @@ struct NVGPUMBarrierTryWaitParityLowering Value ticks = truncToI32(b, adaptor.getTicks()); Value phase = b.create(b.getI32Type(), adaptor.getPhaseParity()); + Type retType = rewriter.getI1Type(); if (isMbarrierShared(op.getBarriers().getType())) { rewriter.replaceOpWithNewOp( - op, barrier, phase, ticks); + op, retType, barrier, phase, ticks); return success(); } - rewriter.replaceOpWithNewOp(op, barrier, - phase, ticks); + rewriter.replaceOpWithNewOp( + op, retType, barrier, phase, ticks); return success(); } }; diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp index 4e256aea0be37..356c1621f81e6 100644 --- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp +++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp @@ -1018,7 +1018,7 @@ void HopperBuilder::buildTryWaitParity( Value ticksBeforeRetry = rewriter.create(loc, 10000000); Value zero = rewriter.create(loc, 0); - rewriter.create(loc, barrier, parity, + rewriter.create(loc, i1, barrier, parity, ticksBeforeRetry, zero); } diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index 86a552c03a473..5d8e5d1e5a2db 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -558,7 +558,7 @@ func.func @mbarrier_wait(%barriers : !nvgpu.mbarrier.group i1 { %num_threads = arith.constant 128 : index // CHECK: %[[c0:.+]] = arith.constant 0 : index // CHECK: %[[mid:.+]] = builtin.unrealized_conversion_cast %[[c0]] : index to i64 @@ -568,50 +568,49 @@ func.func @mbarrier_txcount() { %barrier = nvgpu.mbarrier.create -> !barrierType // CHECK: %[[barStr:.+]] = builtin.unrealized_conversion_cast %[[barMemref]] : memref<1xi64, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr:.+]] = llvm.getelementptr %[[base]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 // CHECK: nvvm.mbarrier.init.shared %[[barPtr]] nvgpu.mbarrier.init %barrier[%c0], %num_threads : !barrierType - + %tidxreg = nvvm.read.ptx.sreg.tid.x : i32 %tidx = arith.index_cast %tidxreg : i32 to index - %cnd = arith.cmpi eq, %tidx, %c0 : index + %cnd = arith.cmpi eq, %tidx, %c0 : index scf.if %cnd { %txcount = arith.constant 256 : index - // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]] nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType - scf.yield + scf.yield } else { %txcount = arith.constant 0 : index - // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]] nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType - scf.yield + scf.yield } - %phase_c0 = arith.constant 0 : i1 %ticks = arith.constant 10000000 : index - // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]] - nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType + // CHECK: %[[isDone:.+]] = nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]] + %isDone = nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType - func.return + func.return %isDone : i1 } // CHECK-LABEL: func @mbarrier_txcount_pred -func.func @mbarrier_txcount_pred() { +func.func @mbarrier_txcount_pred() -> i1 { %mine = arith.constant 1 : index // CHECK: %[[c0:.+]] = arith.constant 0 : index // CHECK: %[[mid:.+]] = builtin.unrealized_conversion_cast %[[c0]] : index to i64 // CHECK: %[[S2:.+]] = gpu.thread_id x // CHECK: %[[P:.+]] = arith.cmpi eq, %[[S2]], %[[c0]] : index - %c0 = arith.constant 0 : index + %c0 = arith.constant 0 : index %tidx = gpu.thread_id x %pred = arith.cmpi eq, %tidx, %c0 : index @@ -619,25 +618,25 @@ func.func @mbarrier_txcount_pred() { %barrier = nvgpu.mbarrier.create -> !barrierType // CHECK: %[[barStr:.+]] = builtin.unrealized_conversion_cast %[[barMemref]] : memref<1xi64, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> - // CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: %[[base:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr:.+]] = llvm.getelementptr %[[base]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 // CHECK: nvvm.mbarrier.init.shared %[[barPtr]], {{.*}}, predicate = %[[P]] nvgpu.mbarrier.init %barrier[%c0], %mine, predicate = %pred : !barrierType %txcount = arith.constant 256 : index - // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]], {{.*}}, predicate = %[[P]] nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount, predicate = %pred : !barrierType %phase_c0 = arith.constant 0 : i1 %ticks = arith.constant 10000000 : index - // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]] - nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType + // CHECK: %[[isDone:.+]] = nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]] + %isDone = nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType - func.return + func.return %isDone : i1 } // CHECK-LABEL: func @async_tma_load diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir index 21947c242461e..a2ec210985863 100644 --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -32,35 +32,21 @@ llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount } // CHECK-LABEL: @init_mbarrier_try_wait_shared -llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32, %phase : i32) { - // CHECK: llvm.inline_asm has_side_effects asm_dialect = att - // CHECK-SAME: "{ - // CHECK-SAME: .reg .pred P1; - // CHECK-SAME: LAB_WAIT: - // CHECK-SAME: mbarrier.try_wait.parity.shared.b64 P1, [$0], $1, $2; - // CHECK-SAME: @P1 bra.uni DONE; - // CHECK-SAME: bra.uni LAB_WAIT; - // CHECK-SAME: DONE: - // CHECK-SAME: }", - // CHECK-SAME: "r,r,r" - nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32 - llvm.return +llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32, %phase : i32) -> i1 { + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att + // CHECK-SAME: "mbarrier.try_wait.parity.shared.b64 $0, [$1], $2, $3;", + // CHECK-SAME: "=b,r,r,r" + %isDone = nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32 -> i1 + llvm.return %isDone : i1 } // CHECK-LABEL: @init_mbarrier_try_wait -llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %ticks : i32, %phase : i32){ +llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %ticks : i32, %phase : i32) -> i1 { // CHECK: llvm.inline_asm has_side_effects asm_dialect = att - // CHECK-SAME: "{ - // CHECK-SAME: .reg .pred P1; - // CHECK-SAME: LAB_WAIT: - // CHECK-SAME: mbarrier.try_wait.parity.b64 P1, [$0], $1, $2; - // CHECK-SAME: @P1 bra.uni DONE; - // CHECK-SAME: bra.uni LAB_WAIT; - // CHECK-SAME: DONE: - // CHECK-SAME: }", - // CHECK-SAME: "l,r,r" - nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr, i32, i32 - llvm.return + // CHECK-SAME: "mbarrier.try_wait.parity.b64 $0, [$1], $2, $3;", + // CHECK-SAME: "=b,l,r,r" + %isDone = nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr, i32, i32 -> i1 + llvm.return %isDone : i1 } // CHECK-LABEL: @async_cp