[mlir][nvgpu] Select warpgroup id on warpgroup.mma.store
#85820
base: main
Conversation
`warpgroup.mma.store` Op is run by a warpgroup that stores fragmented registers to a destination memref. Currently, this op always uses warpgroup 0. This PR adds a new operand to `warpgroup.mma.store` Op that allows selecting a different warpgroup.
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Guray Ozen (grypp)

Changes

This PR adds a new operand to `warpgroup.mma.store` Op that allows selecting a different warpgroup.
Full diff: https://github.com/llvm/llvm-project/pull/85820.diff

4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index dda8f31e688fe9..d22b8fd28582c7 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -775,17 +775,21 @@ def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
The `nvgpu.warpgroup.mma.store` op performs the store of fragmented result
in $matrixD to given memref.
+ Note that the op must be run by a warpgroup. The `warpgroupId` operand
+ allows selecting the warp group that runs the operation. When it is not
+ present, the first warp group runs the operation.
+
[See the details of register fragment layout for accumulator matrix D]
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
-
- Note that, the op must be run with warp group.
}];
let arguments = (ins NVGPU_WarpgroupAccumulator:$matrixD,
- Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
+ Arg<AnyMemRef, "", [MemWrite]>:$dstMemref,
+ Optional<I32>:$warpgroupId);
let assemblyFormat = [{
- $matrixD `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
+ $matrixD `,` $dstMemref (`,` `warpgroup_id` `=` $warpgroupId^)?
+ attr-dict `:` type($matrixD) `to` type($dstMemref)
}];
let hasVerifier = 1;
}
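With this change, the custom assembly accepts an optional trailing `warpgroup_id` clause. A sketch of both forms, with illustrative SSA names:

```mlir
// Without the operand: the first warpgroup stores the result (existing behavior).
nvgpu.warpgroup.mma.store %matrixD, %dst
  : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>>
  to memref<64x16xf32, 3>

// With the operand: the warpgroup selected by %wgId stores the result.
nvgpu.warpgroup.mma.store %matrixD, %dst, warpgroup_id = %wgId
  : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>>
  to memref<64x16xf32, 3>
```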
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index 19070f6f062a02..0d16b825821252 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -23,6 +23,9 @@
constexpr int kWarpSize = 32;
+/// Number of threads in a warpgroup
+constexpr int kWarpgroupSize = 128;
+
/// M size of wgmma.mma_async instruction
constexpr int kWgmmaSizeM = 64;
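`kWarpgroupSize` encodes that a warpgroup is 4 warps of `kWarpSize` (32) threads, i.e. 128 threads. As a hedged sketch (not part of this PR), a kernel could derive its own warpgroup id from the thread id like this:

```mlir
// Warpgroup id = thread id / 128: threads 0-127 -> 0, 128-255 -> 1, ...
%tid  = nvvm.read.ptx.sreg.tid.x : i32
%c128 = arith.constant 128 : i32        // kWarpgroupSize
%wgId = arith.divui %tid, %c128 : i32
```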
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 9b5d19ebd783a9..5a8f095b84c791 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1087,10 +1087,9 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
// // [0,14) start_address
dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
- LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
- << "leading_off:" << leadDimVal << "\t"
- << "stride_off :" << strideDimVal << "\t"
- << "base_offset:" << offsetVal << "\t"
+ LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: " << "leading_off:"
+ << leadDimVal << "\t" << "stride_off :" << strideDimVal
+ << "\t" << "base_offset:" << offsetVal << "\t"
<< "layout_type:" << swizzle << " ("
<< nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
<< ")\n start_addr : " << baseAddr << "\n");
@@ -1382,13 +1381,12 @@ struct NVGPUWarpgroupMmaOpLowering
/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
/// descriptors and arranges them based on induction variables: i, j, and k.
Value generateWgmma(int i, int j, int k, Value matrixC) {
- LLVM_DEBUG(DBGS() << "\t wgmma."
- << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
- << "(A[" << (iterationM * wgmmaM) << ":"
+ LLVM_DEBUG(DBGS() << "\t wgmma." << "m" << wgmmaM << "n" << wgmmaN << "k"
+ << wgmmaK << "(A[" << (iterationM * wgmmaM) << ":"
<< (iterationM * wgmmaM) + wgmmaM << "]["
<< (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "] * "
- << " B[" << (iterationK * wgmmaK) << ":"
+ << (iterationK * wgmmaK + wgmmaK) << "] * " << " B["
+ << (iterationK * wgmmaK) << ":"
<< (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
<< wgmmaN << "])\n");
@@ -1535,8 +1533,8 @@ struct NVGPUWarpgroupMmaStoreOpLowering
/// \param offset: the offset within the memref where the registers will be
/// stored.
void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
- TypedValue<MemRefType> dstMemref,
- int offset) const {
+ TypedValue<MemRefType> dstMemref, int offset,
+ Value warpgroupId) const {
Type i32 = b.getI32Type();
auto makeConst = [&](int32_t index) -> Value {
@@ -1569,6 +1567,13 @@ struct NVGPUWarpgroupMmaStoreOpLowering
};
Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
+ // Normalize the thread index to the beginning of the warpgroup
+ if (warpgroupId) {
+ Value s1 =
+ b.create<arith::MulIOp>(warpgroupId, makeConst(kWarpgroupSize));
+ tidx = b.create<arith::SubIOp>(tidx, s1);
+ }
+
Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
@@ -1610,7 +1615,8 @@ struct NVGPUWarpgroupMmaStoreOpLowering
for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
auto structType = matrixD.cast<LLVM::LLVMStructType>();
Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
- storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
+ storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset,
+ adaptor.getWarpgroupId());
offset += structType.getBody().size();
}
rewriter.eraseOp(op);
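This normalization makes the rest of the store lowering warpgroup-relative: for `warpgroupId = 1`, threads 128..255 are remapped to 0..127 before the lane/warp decomposition. The emitted IR (as the test below checks) is roughly:

```mlir
%tid  = nvvm.read.ptx.sreg.tid.x : i32
%c128 = llvm.mlir.constant(128 : i32) : i32
%off  = arith.muli %wgId, %c128 : i32   // first thread of the selected warpgroup
%ntid = arith.subi %tid, %off : i32     // warpgroup-relative thread index
```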
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index dbf8ead49f78db..ae9fcd287bdded 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1182,6 +1182,52 @@ func.func @warpgroup_mma_store_multiple(
return
}
+// CHECK-LABEL: @warpgroup_mma_store_multiple_with_id(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: memref<64x16xf32, 3>, %[[arg2:[a-zA-Z0-9_]+]]: i32)
+func.func @warpgroup_mma_store_multiple_with_id(
+ %res_m64n16k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>>,
+ %shmem_m64n16k : memref<64x16xf32, 3>,
+ %id : i32)
+{
+ // CHECK: %[[s0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> to !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32)>)>
+ // CHECK: %[[s1:.*]] = builtin.unrealized_conversion_cast %[[arg1]] : memref<64x16xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[s2:.*]] = llvm.extractvalue %[[s0]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32)>)>
+ // CHECK: %[[s3:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[s4:.*]] = llvm.mlir.constant(2 : i32) : i32
+ // CHECK: %[[s5:.*]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: %[[s6:.*]] = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: %[[s7:.*]] = llvm.mlir.constant(16 : i32) : i32
+ // CHECK: %[[s8:.*]] = llvm.mlir.constant(32 : i32) : i32
+ // CHECK: %[[s9:.*]] = nvvm.read.ptx.sreg.tid.x : i32
+ // CHECK: %[[s10:.*]] = llvm.mlir.constant(128 : i32) : i32
+ // CHECK: %[[s11:.*]] = arith.muli %[[arg2]], %[[s10]] : i32
+ // CHECK: %[[s12:.*]] = arith.subi %[[s9]], %[[s11]] : i32
+ // CHECK: %[[s13:.*]] = llvm.urem %[[s12]], %[[s8]] : i32
+ // CHECK: %[[s14:.*]] = llvm.udiv %[[s12]], %[[s8]] : i32
+ // CHECK: %[[s15:.*]] = llvm.udiv %[[s13]], %[[s5]] : i32
+ // CHECK: %[[s16:.*]] = llvm.urem %[[s13]], %[[s5]] : i32
+ // CHECK: %[[s17:.*]] = llvm.mul %[[s16]], %[[s4]] : i32
+ // CHECK: %[[s18:.*]] = llvm.mul %[[s14]], %[[s7]] : i32
+ // CHECK: %[[s19:.*]] = llvm.add %[[s15]], %[[s18]] : i32
+ // CHECK: %[[s20:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[s21:.*]] = llvm.mul %[[s20]], %[[s6]] : i32
+ // CHECK: %[[s22:.*]] = llvm.add %[[s19]], %[[s21]] : i32
+ // CHECK: %[[s23:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[s24:.*]] = llvm.mul %[[s23]], %[[s6]] : i32
+ // CHECK: %[[s25:.*]] = llvm.add %[[s17]], %[[s24]] : i32
+ // CHECK: %[[s26:.*]] = arith.index_cast %[[s22]] : i32 to index
+ // CHECK: %[[s27:.*]] = arith.index_cast %[[s25]] : i32 to index
+ // CHECK: %[[s28:.*]] = llvm.add %[[s25]], %[[s3]] : i32
+ // CHECK: %[[s29:.*]] = arith.index_cast %[[s28]] : i32 to index
+ // CHECK: %[[s30:.*]] = llvm.extractvalue %[[s2]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: %[[s31:.*]] = llvm.extractvalue %[[s2]][1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
+ // CHECK: memref.store %[[s30]], %[[arg1]][%[[s26]], %[[s27]]] : memref<64x16xf32, 3>
+ // CHECK: memref.store %[[s31]], %[[arg1]][%[[s26]], %[[s29]]] : memref<64x16xf32, 3>
+ // CHECK-COUNT-6: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x16xf32, 3>
+ nvgpu.warpgroup.mma.store %res_m64n16k, %shmem_m64n16k, warpgroup_id = %id : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> to memref<64x16xf32,3>
+ return
+}
+
func.func @warpgroup_mma_init() {
//CHECK: %[[S1:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
//CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
The `warpgroup.mma.store` operation stores fragmented registers to a destination memref and currently always uses warpgroup 0. This PR introduces the ability to select a different warpgroup: one can specify a `warpgroup_id` operand to choose the warpgroup that stores the data. Here's an example of how you can use it:
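The snippet below mirrors the `warpgroup_mma_store_multiple_with_id` test added in this PR; `%id` is assumed to be an `i32` SSA value holding the desired warpgroup id:

```mlir
// Store the fragmented accumulator; the warpgroup selected by %id
// (threads [128*id, 128*id + 128)) performs the store.
nvgpu.warpgroup.mma.store %result, %shmem, warpgroup_id = %id
  : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>>
  to memref<64x16xf32, 3>
```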