From 9d4bb956dd4e78c0fb9bae69372064a5192a3762 Mon Sep 17 00:00:00 2001
From: grypp
Date: Tue, 19 Mar 2024 16:53:28 +0000
Subject: [PATCH] [mlir][nvgpu] Select warpgroup id on `warpgroup.mma.store`

The `warpgroup.mma.store` op is run by a warpgroup, which stores the
fragmented registers to the destination memref. Currently, the op always
uses warpgroup 0. This PR adds a new operand to `warpgroup.mma.store` that
allows selecting a different warpgroup.
---
 mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td   | 12 +++--
 .../mlir/Dialect/NVGPU/IR/NVGPUDialect.h      |  3 ++
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    | 30 +++++++-----
 .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 46 +++++++++++++++++++
 4 files changed, 75 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index dda8f31e688fe..d22b8fd28582c 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -775,17 +775,21 @@ def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
     The `nvgpu.warpgroup.mma.store` op performs the store of fragmented
     result in $matrixD to given memref.
 
+    Note that the op must be run by a warpgroup. The operand `warpgroupId`
+    selects the warpgroup that runs the operation. When it is not present,
+    the first warpgroup runs the operation.
+
     [See the details of register fragment layout for accumulator matrix D]
     (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
-
-    Note that, the op must be run with warp group.
   }];
 
   let arguments = (ins NVGPU_WarpgroupAccumulator:$matrixD,
-                       Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
+                       Arg<AnyMemRef, "", [MemWrite]>:$dstMemref,
+                       Optional<I32>:$warpgroupId);
   let assemblyFormat = [{
-    $matrixD `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
+    $matrixD `,` $dstMemref (`,` `warpgroup_id` `=` $warpgroupId^)?
+    attr-dict `:` type($matrixD) `to` type($dstMemref)
   }];
   let hasVerifier = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index 19070f6f062a0..0d16b82582125 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -23,6 +23,9 @@
 
 constexpr int kWarpSize = 32;
 
+/// Number of threads in a warpgroup
+constexpr int kWarpgroupSize = 128;
+
 /// M size of wgmma.mma_async instruction
 constexpr int kWgmmaSizeM = 64;
 
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 9b5d19ebd783a..5a8f095b84c79 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1087,10 +1087,9 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
     // // [0,14) start_address
     dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
 
-    LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
-                      << "leading_off:" << leadDimVal << "\t"
-                      << "stride_off :" << strideDimVal << "\t"
-                      << "base_offset:" << offsetVal << "\t"
+    LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: " << "leading_off:"
+                      << leadDimVal << "\t" << "stride_off :" << strideDimVal
+                      << "\t" << "base_offset:" << offsetVal << "\t"
                       << "layout_type:" << swizzle << " ("
                       << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
                       << ")\n start_addr : " << baseAddr << "\n");
@@ -1382,13 +1381,12 @@ struct NVGPUWarpgroupMmaOpLowering
   /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
   /// descriptors and arranges them based on induction variables: i, j, and k.
   Value generateWgmma(int i, int j, int k, Value matrixC) {
-    LLVM_DEBUG(DBGS() << "\t wgmma."
-                      << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
-                      << "(A[" << (iterationM * wgmmaM) << ":"
+    LLVM_DEBUG(DBGS() << "\t wgmma." << "m" << wgmmaM << "n" << wgmmaN << "k"
+                      << wgmmaK << "(A[" << (iterationM * wgmmaM) << ":"
                       << (iterationM * wgmmaM) + wgmmaM << "]["
                       << (iterationK * wgmmaK) << ":"
-                      << (iterationK * wgmmaK + wgmmaK) << "] * "
-                      << " B[" << (iterationK * wgmmaK) << ":"
+                      << (iterationK * wgmmaK + wgmmaK) << "] * " << " B["
+                      << (iterationK * wgmmaK) << ":"
                       << (iterationK * wgmmaK + wgmmaK) << "]["
                       << 0 << ":" << wgmmaN << "])\n");
 
@@ -1535,8 +1533,8 @@ struct NVGPUWarpgroupMmaStoreOpLowering
   /// \param offset: the offset within the memref where the registers will be
   /// stored.
   void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
-                             TypedValue<MemRefType> dstMemref,
-                             int offset) const {
+                             TypedValue<MemRefType> dstMemref, int offset,
+                             Value warpgroupId) const {
     Type i32 = b.getI32Type();
 
     auto makeConst = [&](int32_t index) -> Value {
@@ -1569,6 +1567,13 @@ struct NVGPUWarpgroupMmaStoreOpLowering
     };
 
     Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
+
+    // Normalize the thread index to the beginning of the warpgroup
+    if (warpgroupId) {
+      Value s1 =
+          b.create<arith::MulIOp>(warpgroupId, makeConst(kWarpgroupSize));
+      tidx = b.create<arith::SubIOp>(tidx, s1);
+    }
+
     Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
     Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
     Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
@@ -1610,7 +1615,8 @@ struct NVGPUWarpgroupMmaStoreOpLowering
     for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
       auto structType = matrixD.cast<LLVM::LLVMStructType>();
       Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
-      storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
+      storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset,
+                            adaptor.getWarpgroupId());
       offset += structType.getBody().size();
     }
     rewriter.eraseOp(op);
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index dbf8ead49f78d..ae9fcd287bdde 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1182,6 +1182,52 @@ func.func @warpgroup_mma_store_multiple(
   return
 }
 
+// CHECK-LABEL: @warpgroup_mma_store_multiple_with_id(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: memref<64x16xf32, 3>, %[[arg2:[a-zA-Z0-9_]+]]: i32)
+func.func @warpgroup_mma_store_multiple_with_id(
+    %res_m64n16k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>>,
+    %shmem_m64n16k : memref<64x16xf32, 3>,
+    %id : i32)
+{
+  // CHECK: %[[s0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> to !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32)>)>
+  // CHECK: %[[s1:.*]] = builtin.unrealized_conversion_cast %[[arg1]] : memref<64x16xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[s2:.*]] = llvm.extractvalue %[[s0]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32)>)>
+  // CHECK: %[[s3:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[s4:.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: %[[s5:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: %[[s6:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[s7:.*]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[s8:.*]] = llvm.mlir.constant(32 : i32) : i32
+  // CHECK: %[[s9:.*]] = nvvm.read.ptx.sreg.tid.x : i32
+  // CHECK: %[[s10:.*]] = llvm.mlir.constant(128 : i32) : i32
+  // CHECK: %[[s11:.*]] = arith.muli %[[arg2]], %[[s10]] : i32
+  // CHECK: %[[s12:.*]] = arith.subi %[[s9]], %[[s11]] : i32
+  // CHECK: %[[s13:.*]] = llvm.urem %[[s12]], %[[s8]] : i32
+  // CHECK: %[[s14:.*]] = llvm.udiv %[[s12]], %[[s8]] : i32
+  // CHECK: %[[s15:.*]] = llvm.udiv %[[s13]], %[[s5]] : i32
+  // CHECK: %[[s16:.*]] = llvm.urem %[[s13]], %[[s5]] : i32
+  // CHECK: %[[s17:.*]] = llvm.mul %[[s16]], %[[s4]] : i32
+  // CHECK: %[[s18:.*]] = llvm.mul %[[s14]], %[[s7]] : i32
+  // CHECK: %[[s19:.*]] = llvm.add %[[s15]], %[[s18]] : i32
+  // CHECK: %[[s20:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[s21:.*]] = llvm.mul %[[s20]], %[[s6]] : i32
+  // CHECK: %[[s22:.*]] = llvm.add %[[s19]], %[[s21]] : i32
+  // CHECK: %[[s23:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[s24:.*]] = llvm.mul %[[s23]], %[[s6]] : i32
+  // CHECK: %[[s25:.*]] = llvm.add %[[s17]], %[[s24]] : i32
+  // CHECK: %[[s26:.*]] = arith.index_cast %[[s22]] : i32 to index
+  // CHECK: %[[s27:.*]] = arith.index_cast %[[s25]] : i32 to index
+  // CHECK: %[[s28:.*]] = llvm.add %[[s25]], %[[s3]] : i32
+  // CHECK: %[[s29:.*]] = arith.index_cast %[[s28]] : i32 to index
+  // CHECK: %[[s30:.*]] = llvm.extractvalue %[[s2]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  // CHECK: %[[s31:.*]] = llvm.extractvalue %[[s2]][1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+  // CHECK: memref.store %[[s30]], %[[arg1]][%[[s26]], %[[s27]]] : memref<64x16xf32, 3>
+  // CHECK: memref.store %[[s31]], %[[arg1]][%[[s26]], %[[s29]]] : memref<64x16xf32, 3>
+  // CHECK-COUNT-6: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x16xf32, 3>
+  nvgpu.warpgroup.mma.store %res_m64n16k, %shmem_m64n16k, warpgroup_id = %id : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> to memref<64x16xf32,3>
+  return
+}
+
 func.func @warpgroup_mma_init() {
   //CHECK: %[[S1:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f3
   //CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
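A minimal usage sketch (illustrative, not part of the patch): in a block that launches two warpgroups (256 threads), the warpgroup id can be derived from the thread id and passed to the new operand so that each warpgroup stores its own accumulator tile. The `%acc` and `%shmem` values and the 64x16 tile shape below are assumptions for the example, not taken from the patch.

  // Hypothetical kernel fragment: derive the warpgroup id from the thread id.
  %tidx = gpu.thread_id x
  %c128 = arith.constant 128 : index
  %wg = arith.divui %tidx, %c128 : index
  %wg_i32 = arith.index_cast %wg : index to i32
  // Each warpgroup stores its own fragment; omitting `warpgroup_id` keeps the
  // previous behavior, i.e. warpgroup 0 performs the store.
  nvgpu.warpgroup.mma.store %acc, %shmem, warpgroup_id = %wg_i32 :
    !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> to memref<64x16xf32, 3>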