[mlir][nvgpu] Select warpgroup id on warpgroup.mma.store #85820

Open · wants to merge 1 commit into `main`
12 changes: 8 additions & 4 deletions mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -775,17 +775,21 @@ def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
The `nvgpu.warpgroup.mma.store` op performs the store of the fragmented
result in $matrixD to the given memref.

Note that the op must be run by a warpgroup. The optional operand
`warpgroupId` selects which warpgroup runs the operation; when it is not
present, the first warpgroup runs the operation.

[See the details of register fragment layout for accumulator matrix D]
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
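
Example (a sketch mirroring the test added in this PR; `%id` is an `i32`
value selecting the warpgroup):

```mlir
nvgpu.warpgroup.mma.store %matrixD, %dstMemref, warpgroup_id = %id :
  !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>>
  to memref<64x16xf32, 3>
```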

}];

let arguments = (ins NVGPU_WarpgroupAccumulator:$matrixD,
Arg<AnyMemRef, "", [MemWrite]>:$dstMemref,
Optional<I32>:$warpgroupId);

let assemblyFormat = [{
$matrixD `,` $dstMemref (`,` `warpgroup_id` `=` $warpgroupId^)?
attr-dict `:` type($matrixD) `to` type($dstMemref)
}];
let hasVerifier = 1;
}
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -23,6 +23,9 @@

constexpr int kWarpSize = 32;

/// Number of threads in a warpgroup
constexpr int kWarpgroupSize = 128;

/// M size of wgmma.mma_async instruction
constexpr int kWgmmaSizeM = 64;

30 changes: 18 additions & 12 deletions mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1087,10 +1087,9 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
// // [0,14) start_address
dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);

LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
<< "leading_off:" << leadDimVal << "\t"
<< "stride_off :" << strideDimVal << "\t"
<< "base_offset:" << offsetVal << "\t"
LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: " << "leading_off:"
<< leadDimVal << "\t" << "stride_off :" << strideDimVal
<< "\t" << "base_offset:" << offsetVal << "\t"
<< "layout_type:" << swizzle << " ("
<< nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
<< ")\n start_addr : " << baseAddr << "\n");
@@ -1382,13 +1381,12 @@ struct NVGPUWarpgroupMmaOpLowering
/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
/// descriptors and arranges them based on induction variables: i, j, and k.
Value generateWgmma(int i, int j, int k, Value matrixC) {
LLVM_DEBUG(DBGS() << "\t wgmma." << "m" << wgmmaM << "n" << wgmmaN << "k"
<< wgmmaK << "(A[" << (iterationM * wgmmaM) << ":"
<< (iterationM * wgmmaM) + wgmmaM << "]["
<< (iterationK * wgmmaK) << ":"
<< (iterationK * wgmmaK + wgmmaK) << "] * "
<< " B[" << (iterationK * wgmmaK) << ":"
<< (iterationK * wgmmaK + wgmmaK) << "] * " << " B["
<< (iterationK * wgmmaK) << ":"
<< (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
<< wgmmaN << "])\n");

@@ -1535,8 +1533,8 @@ struct NVGPUWarpgroupMmaStoreOpLowering
/// \param offset: the offset within the memref where the registers will be
/// stored.
void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
TypedValue<MemRefType> dstMemref, int offset,
Value warpgroupId) const {
Type i32 = b.getI32Type();

auto makeConst = [&](int32_t index) -> Value {
Expand Down Expand Up @@ -1569,6 +1567,13 @@ struct NVGPUWarpgroupMmaStoreOpLowering
};

Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
// Normalize the thread index to the beginning of the warpgroup
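// For example, warpgroupId = 1 remaps threads 128-255 to tidx 0-127, so
// the lane/warp arithmetic below is identical for every warpgroup.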
if (warpgroupId) {
Value s1 =
b.create<arith::MulIOp>(warpgroupId, makeConst(kWarpgroupSize));
tidx = b.create<arith::SubIOp>(tidx, s1);
}

Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
@@ -1610,7 +1615,8 @@
for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
auto structType = matrixD.cast<LLVM::LLVMStructType>();
Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset,
adaptor.getWarpgroupId());
offset += structType.getBody().size();
}
rewriter.eraseOp(op);
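The normalization above lets each warpgroup compute the same lane/warp
coordinates from its own 128 threads. Below is a minimal host-side C++ sketch
of the index arithmetic the lowering emits (an illustration only, not the
actual NVVM lowering; `lane4Id` follows the `udiv` by 4 in the code above):

```cpp
#include <cstdio>

constexpr int kWarpSize = 32;       // threads per warp
constexpr int kWarpgroupSize = 128; // threads per warpgroup (4 warps)

int main() {
  int warpgroupId = 1; // the store is performed by the second warpgroup
  for (int tid : {128, 131, 255}) {
    int tidx = tid - warpgroupId * kWarpgroupSize; // normalized thread index
    int laneId = tidx % kWarpSize;                 // lane within the warp
    int warpId = tidx / kWarpSize;                 // warp within the warpgroup
    int lane4Id = laneId / 4;                      // groups of four lanes
    std::printf("tid=%3d -> tidx=%3d warp=%d lane=%2d lane4=%d\n", tid, tidx,
                warpId, laneId, lane4Id);
  }
  return 0;
}
```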
46 changes: 46 additions & 0 deletions mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1182,6 +1182,52 @@ func.func @warpgroup_mma_store_multiple(
return
}

// CHECK-LABEL: @warpgroup_mma_store_multiple_with_id(
// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: memref<64x16xf32, 3>, %[[arg2:[a-zA-Z0-9_]+]]: i32)
func.func @warpgroup_mma_store_multiple_with_id(
%res_m64n16k : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>>,
%shmem_m64n16k : memref<64x16xf32, 3>,
%id : i32)
{
// CHECK: %[[s0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> to !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: %[[s1:.*]] = builtin.unrealized_conversion_cast %[[arg1]] : memref<64x16xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[s2:.*]] = llvm.extractvalue %[[s0]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: %[[s3:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[s4:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[s5:.*]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: %[[s6:.*]] = llvm.mlir.constant(8 : i32) : i32
// CHECK: %[[s7:.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: %[[s8:.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: %[[s9:.*]] = nvvm.read.ptx.sreg.tid.x : i32
// CHECK: %[[s10:.*]] = llvm.mlir.constant(128 : i32) : i32
// CHECK: %[[s11:.*]] = arith.muli %[[arg2]], %[[s10]] : i32
// CHECK: %[[s12:.*]] = arith.subi %[[s9]], %[[s11]] : i32
// CHECK: %[[s13:.*]] = llvm.urem %[[s12]], %[[s8]] : i32
// CHECK: %[[s14:.*]] = llvm.udiv %[[s12]], %[[s8]] : i32
// CHECK: %[[s15:.*]] = llvm.udiv %[[s13]], %[[s5]] : i32
// CHECK: %[[s16:.*]] = llvm.urem %[[s13]], %[[s5]] : i32
// CHECK: %[[s17:.*]] = llvm.mul %[[s16]], %[[s4]] : i32
// CHECK: %[[s18:.*]] = llvm.mul %[[s14]], %[[s7]] : i32
// CHECK: %[[s19:.*]] = llvm.add %[[s15]], %[[s18]] : i32
// CHECK: %[[s20:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[s21:.*]] = llvm.mul %[[s20]], %[[s6]] : i32
// CHECK: %[[s22:.*]] = llvm.add %[[s19]], %[[s21]] : i32
// CHECK: %[[s23:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[s24:.*]] = llvm.mul %[[s23]], %[[s6]] : i32
// CHECK: %[[s25:.*]] = llvm.add %[[s17]], %[[s24]] : i32
// CHECK: %[[s26:.*]] = arith.index_cast %[[s22]] : i32 to index
// CHECK: %[[s27:.*]] = arith.index_cast %[[s25]] : i32 to index
// CHECK: %[[s28:.*]] = llvm.add %[[s25]], %[[s3]] : i32
// CHECK: %[[s29:.*]] = arith.index_cast %[[s28]] : i32 to index
// CHECK: %[[s30:.*]] = llvm.extractvalue %[[s2]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: %[[s31:.*]] = llvm.extractvalue %[[s2]][1] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: memref.store %[[s30]], %[[arg1]][%[[s26]], %[[s27]]] : memref<64x16xf32, 3>
// CHECK: memref.store %[[s31]], %[[arg1]][%[[s26]], %[[s29]]] : memref<64x16xf32, 3>
// CHECK-COUNT-6: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x16xf32, 3>
nvgpu.warpgroup.mma.store %res_m64n16k, %shmem_m64n16k, warpgroup_id = %id : !nvgpu.warpgroup.accumulator<fragmented = vector<64x16xf32>> to memref<64x16xf32,3>
return
}

func.func @warpgroup_mma_init() {
//CHECK: %[[S1:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
//CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>