
[mlir][nvgpu] Select warpgroup id on warpgroup.mma.store #85820


Open
wants to merge 1 commit into main
Conversation

grypp
Member

@grypp grypp commented Mar 19, 2024

The `warpgroup.mma.store` operation stores fragmented accumulator registers to a destination memory reference. Currently, it is always executed by warpgroup 0.

This PR introduces a new feature to this operation: the ability to select a different warpgroup. One can specify a warpgroup_id operand, which chooses the warpgroup that stores the data. Here's an example of how you can use it:

nvgpu.warpgroup.mma.store 
%res_m64n16k, 
%shmem_m64n16k, 
warpgroup_id = %id // <-- PR adds this one
: !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> 
  to memref<64x16xf32,3>
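
The `warpgroup_id` is an ordinary i32 SSA value, so it can be computed at runtime. A minimal sketch of deriving it from the thread id (illustrative only, not part of this PR; assumes the op runs inside a `gpu.func` with 128 contiguous threads per warpgroup):

%tid  = gpu.thread_id x
%c128 = arith.constant 128 : index
%wg   = arith.divui %tid, %c128 : index      // linear thread id / warpgroup size
%id   = arith.index_cast %wg : index to i32  // the op expects an i32 warpgroup_id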

`warpgroup.mma.store` Op is run by a warpgroup that stores fragmented registers to the destination memref. Currently, this op always uses warpgroup 0.

This PR adds a new operand to `warpgroup.mma.store` Op that allows selecting a different warpgroup.
@llvmbot
Member

llvmbot commented Mar 19, 2024

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-nvgpu

@llvm/pr-subscribers-mlir

Author: Guray Ozen (grypp)

Changes

warpgroup.mma.store Op is run by a warpgroup that stores fragmented registers to the destination memref. Currently, this op always uses warpgroup 0.

This PR adds a new operand to warpgroup.mma.store Op that allows selecting a different warpgroup. For example:

nvgpu.warpgroup.mma.store 
%res_m64n16k, 
%shmem_m64n16k, 
warpgroup_id = %id // <-- PR adds this one
: !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> 
  to memref<64x16xf32,3>

Full diff: https://github.com/llvm/llvm-project/pull/85820.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td (+8-4)
  • (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h (+3)
  • (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+18-12)
  • (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+46)
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index dda8f31e688fe9..d22b8fd28582c7 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -775,17 +775,21 @@ def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
     The `nvgpu.warpgroup.mma.store` op performs the store of fragmented result 
     in $matrixD to given memref. 
 
+    Note that, the op must be run with warp group. The operand `warpgroupId` 
+    allow to select the warp group to run the operation. When it is not present,
+    the first warp group runs the operation.
+
     [See the details of register fragment layout for accumulator matrix D]
     (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d) 
-
-    Note that, the op must be run with warp group.
   }];
 
   let arguments = (ins NVGPU_WarpgroupAccumulator:$matrixD,
-                       Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
+                       Arg<AnyMemRef, "", [MemWrite]>:$dstMemref,
+                       Optional<I32>:$warpgroupId);
   
   let assemblyFormat = [{
-    $matrixD `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
+    $matrixD `,` $dstMemref (`,` `warpgroup_id` `=` $warpgroupId^)? 
+    attr-dict `:` type($matrixD) `to` type($dstMemref)
   }];
   let hasVerifier = 1;
 }
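
For illustration, with this optional-operand assembly format both spellings parse and print; a sketch with placeholder SSA names (%d, %mem, %id), not taken from the patch:

// Without the operand: warpgroup 0 performs the store
nvgpu.warpgroup.mma.store %d, %mem
  : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> to memref<64x16xf32,3>
// With the operand: the selected warpgroup performs the store
nvgpu.warpgroup.mma.store %d, %mem, warpgroup_id = %id
  : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> to memref<64x16xf32,3>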
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index 19070f6f062a02..0d16b825821252 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -23,6 +23,9 @@
 
 constexpr int kWarpSize = 32;
 
+/// Number of threads in warpgroup
+constexpr int kWarpgroupSize = 128;
+
 /// M size of wgmma.mma_async instruction
 constexpr int kWgmmaSizeM = 64;
 
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 9b5d19ebd783a9..5a8f095b84c791 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1087,10 +1087,9 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
     // // [0,14)   start_address
     dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
 
-    LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
-                      << "leading_off:" << leadDimVal << "\t"
-                      << "stride_off :" << strideDimVal << "\t"
-                      << "base_offset:" << offsetVal << "\t"
+    LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: " << "leading_off:"
+                      << leadDimVal << "\t" << "stride_off :" << strideDimVal
+                      << "\t" << "base_offset:" << offsetVal << "\t"
                       << "layout_type:" << swizzle << " ("
                       << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
                       << ")\n start_addr :  " << baseAddr << "\n");
@@ -1382,13 +1381,12 @@ struct NVGPUWarpgroupMmaOpLowering
     /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
     /// descriptors and arranges them based on induction variables: i, j, and k.
     Value generateWgmma(int i, int j, int k, Value matrixC) {
-      LLVM_DEBUG(DBGS() << "\t wgmma."
-                        << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
-                        << "(A[" << (iterationM * wgmmaM) << ":"
+      LLVM_DEBUG(DBGS() << "\t wgmma." << "m" << wgmmaM << "n" << wgmmaN << "k"
+                        << wgmmaK << "(A[" << (iterationM * wgmmaM) << ":"
                         << (iterationM * wgmmaM) + wgmmaM << "]["
                         << (iterationK * wgmmaK) << ":"
-                        << (iterationK * wgmmaK + wgmmaK) << "] * "
-                        << " B[" << (iterationK * wgmmaK) << ":"
+                        << (iterationK * wgmmaK + wgmmaK) << "] * " << " B["
+                        << (iterationK * wgmmaK) << ":"
                         << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
                         << wgmmaN << "])\n");
 
@@ -1535,8 +1533,8 @@ struct NVGPUWarpgroupMmaStoreOpLowering
   /// \param offset: the offset within the memref where the registers will be
   /// stored.
   void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
-                             TypedValue<MemRefType> dstMemref,
-                             int offset) const {
+                             TypedValue<MemRefType> dstMemref, int offset,
+                             Value warpgroupId) const {
     Type i32 = b.getI32Type();
 
     auto makeConst = [&](int32_t index) -> Value {
@@ -1569,6 +1567,13 @@ struct NVGPUWarpgroupMmaStoreOpLowering
     };
 
     Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
+    // Normalize the thread index to the beginning of the warpgroup
+    if (warpgroupId) {
+      Value s1 =
+          b.create<arith::MulIOp>(warpgroupId, makeConst(kWarpgroupSize));
+      tidx = b.create<arith::SubIOp>(tidx, s1);
+    }
+
     Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
     Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
     Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
@@ -1610,7 +1615,8 @@ struct NVGPUWarpgroupMmaStoreOpLowering
     for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
       auto structType = matrixD.cast<LLVM::LLVMStructType>();
       Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
-      storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
+      storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset,
+                            adaptor.getWarpgroupId());
       offset += structType.getBody().size();
     }
     rewriter.eraseOp(op);
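
The change above rebases the thread index to the start of the selected warpgroup (kWarpgroupSize = 128 threads, i.e. 4 warps of 32), so the existing lane/warp arithmetic needs no other changes. Roughly the IR this emits when warpgroup_id is present (SSA names illustrative; the exact output appears in the test below):

%tidx  = nvvm.read.ptx.sreg.tid.x : i32
%c128  = llvm.mlir.constant(128 : i32) : i32     // kWarpgroupSize
%start = arith.muli %warpgroup_id, %c128 : i32   // first thread of the selected warpgroup
%ntidx = arith.subi %tidx, %start : i32          // thread index within that warpgroup
// e.g. warpgroup_id = 1 and tid.x = 130 gives ntidx = 2 (lane 2 of warp 0)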
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index dbf8ead49f78db..ae9fcd287bdded 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1182,6 +1182,52 @@ func.func @warpgroup_mma_store_multiple(
   return 
 }
 
+// CHECK-LABEL: @warpgroup_mma_store_multiple_with_id(  
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: memref<64x16xf32, 3>, %[[arg2:[a-zA-Z0-9_]+]]: i32)
+func.func @warpgroup_mma_store_multiple_with_id(
+  %res_m64n16k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>>,
+  %shmem_m64n16k : memref<64x16xf32, 3>, 
+  %id : i32) 
+{    
+  // CHECK: %[[s0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> to !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32)>)>
+  // CHECK: %[[s1:.*]] = builtin.unrealized_conversion_cast %[[arg1]] : memref<64x16xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[s2:.*]] = llvm.extractvalue %[[s0]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32)>)> 
+  // CHECK: %[[s3:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[s4:.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: %[[s5:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: %[[s6:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[s7:.*]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[s8:.*]] = llvm.mlir.constant(32 : i32) : i32
+  // CHECK: %[[s9:.*]] = nvvm.read.ptx.sreg.tid.x : i32
+  // CHECK: %[[s10:.*]] = llvm.mlir.constant(128 : i32) : i32
+  // CHECK: %[[s11:.*]] = arith.muli %[[arg2]], %[[s10]] : i32
+  // CHECK: %[[s12:.*]] = arith.subi %[[s9]], %[[s11]] : i32
+  // CHECK: %[[s13:.*]] = llvm.urem %12, %8  : i32
+  // CHECK: %[[s14:.*]] = llvm.udiv %[[s12]], %[[s8]]  : i32
+  // CHECK: %[[s15:.*]] = llvm.udiv %[[s13]], %[[s5]]  : i32
+  // CHECK: %[[s16:.*]] = llvm.urem %[[s13]], %[[s5]]  : i32
+  // CHECK: %[[s17:.*]] = llvm.mul %[[s16]], %[[s4]]  : i32
+  // CHECK: %[[s18:.*]] = llvm.mul %[[s14]], %[[s7]]  : i32
+  // CHECK: %[[s19:.*]] = llvm.add %[[s15]], %[[s18]]  : i32
+  // CHECK: %[[s20:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[s21:.*]] = llvm.mul %[[s20]], %[[s6]]  : i32
+  // CHECK: %[[s22:.*]] = llvm.add %[[s19]], %[[s21]]  : i32
+  // CHECK: %[[s23:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[s24:.*]] = llvm.mul %[[s23]], %[[s6]]  : i32
+  // CHECK: %[[s25:.*]] = llvm.add %[[s17]], %[[s24]]  : i32
+  // CHECK: %[[s26:.*]] = arith.index_cast %[[s22]] : i32 to index
+  // CHECK: %[[s27:.*]] = arith.index_cast %[[s25]] : i32 to index
+  // CHECK: %[[s28:.*]] = llvm.add %[[s25]], %[[s3]]  : i32
+  // CHECK: %[[s29:.*]] = arith.index_cast %[[s28]] : i32 to index
+  // CHECK: %[[s30:.*]] = llvm.extractvalue %[[s2]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 
+  // CHECK: %[[s31:.*]] = llvm.extractvalue %[[s2]][1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 
+  // CHECK: memref.store %[[s30]], %[[arg1]][%[[s26]], %[[s27]]] : memref<64x16xf32, 3>
+  // CHECK: memref.store %[[s31]], %[[arg1]][%[[s26]], %[[s29]]] : memref<64x16xf32, 3>
+  // CHECK-COUNT-6: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x16xf32, 3>
+  nvgpu.warpgroup.mma.store  %res_m64n16k, %shmem_m64n16k, warpgroup_id = %id : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> to memref<64x16xf32,3>
+  return
+}
+
 func.func @warpgroup_mma_init() {
  //CHECK: %[[S1:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
   //CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
