Split Hopper MMA by warp-tile before instruction tile (#3642)
Currently we ignore the warp tile parameter when scheduling Hopper
matmuls (see #3636). This PR introduces a test with different CTA, warp,
and instruction tiles and modifies the Hopper scheduler to split by warp
tile in addition to instruction tile. Note that the instruction tile
split results in two serial loop domains, so we wind up executing
multiple mma instructions in each main-loop iteration. In the included
example, `warp_tile` is 64, 128, 16 and the macro is `Hopper_64_8_16`,
so there are 128/8 = 16 instruction tiles per warp tile and the
generated main loop looks like this:
```c++
  #pragma unroll 3
  for(nvfuser_index_t i33 = 0; i33 < i4; ++i33) {
    nvfuser_index_t i34;
    i34 = 48 + (16 * i33);
    nvfuser_index_t i35;
    i35 = (3 + i33) % 4;
    unsigned i36;
    i36 = i7 + (8192 * i35);
    unsigned i37;
    i37 = i10 + (4096 * i35);
    nvfuser_index_t i38;
    i38 = i33 % 4;
    unsigned i39;
    i39 = i13 + (4096 * i38);
    uint64_t i40;
    i40 = 4611686293305294848ULL | ((262143ULL & (uint64_t)(i39)) >> 4ULL);
    unsigned i41;
    i41 = i15 + (8192 * i38);
    if (((Hopper::electSync(4294967295U) && b22) && b23)) {
      mbarrier::arriveExpectTX(toSmem((&T8[((3LL + i33) % 4)])), 8192U);
      #pragma unroll
      for(nvfuser_index_t i31 = 0; i31 < 4; ++i31) {
        Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr5, (Array<nvfuser_index_t, 2, 1>{(i6 + (64 * i31)), i34}), toSmem((&T8[((3LL + i33) % 4)])) }), (i36 + (2048 * i31)));
      }
      mbarrier::arriveExpectTX(toSmem((&T8[((3LL + i33) % 4)])), 4096U);
      #pragma unroll
      for(nvfuser_index_t i32 = 0; i32 < 2; ++i32) {
        Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr8, (Array<nvfuser_index_t, 2, 1>{(i9 + (64 * i32)), i34}), toSmem((&T8[((3LL + i33) % 4)])) }), (i37 + (2048 * i32)));
      }
    }
    mbarrier::waitParity(toSmem((&T8[(i33 % 4)])), (uint32_t)(((i33 / 4) % 2)));
    #pragma unroll
    for(nvfuser_index_t i25 = 0; i25 < 16; ++i25) {
      unsigned i42;
      i42 = (i41 + (2048 * (i25 / 8))) + (16 * (i25 % 8));
      asm volatile(
        "{\n"
        "  .reg .pred p0; \n"
        "  setp.ne.b32 p0, %6, 0;\n"
        "  wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 {%0, %1, %2, %3}, %4, %5, p0, %7, %8, %9, %10;\n"
        "}\n"
        :"+f"((*reinterpret_cast<Array<float, 4, 1>*>(&T2[(4 * i25)]))[0]),
         "+f"((*reinterpret_cast<Array<float, 4, 1>*>(&T2[(4 * i25)]))[1]),
         "+f"((*reinterpret_cast<Array<float, 4, 1>*>(&T2[(4 * i25)]))[2]),
         "+f"((*reinterpret_cast<Array<float, 4, 1>*>(&T2[(4 * i25)]))[3])
        :"l"(i40),
         "l"((4611686293305294848ULL | ((262143ULL & (uint64_t)(i42)) >> 4ULL))),
         "n"((uint32_t)(true)),
         "n"(1),
         "n"(1),
         "n"(1),
         "n"(1)
      );
    }
    __syncthreads();
    asm volatile("wgmma.commit_group.sync.aligned;\n");
    asm volatile("wgmma.wait_group.sync.aligned %0;\n"::"n"(0LL):"memory");
  }
```
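To make the tile arithmetic concrete, the `i25 < 16` inner loop above is just the instruction-tile count per warp tile. Here is a minimal standalone sketch of that arithmetic (the `Tile` struct is illustrative only, not nvFuser's `GemmTile`):

```c++
#include <cstdint>
#include <cstdio>

// Three-level tile hierarchy: CTA tile -> warp tile -> instruction tile.
struct Tile {
  int64_t m, n, k;
};

int main() {
  Tile warp{64, 128, 16};  // warp_tile from the example above
  Tile inst{64, 8, 16};    // Hopper_64_8_16 macro
  // Number of mma instructions issued per warp tile per main-loop iteration:
  int64_t count = (warp.m / inst.m) * (warp.n / inst.n) * (warp.k / inst.k);
  std::printf("%lld\n", static_cast<long long>(count));  // prints 16
  return 0;
}
```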

Fixes #3636

---------

Co-authored-by: Ryan Spring <[email protected]>
jacobhinkle and rdspring1 authored Feb 6, 2025
1 parent 5bfaa0e commit b076a55
Showing 5 changed files with 213 additions and 45 deletions.
2 changes: 1 addition & 1 deletion csrc/device_lower/pass/index.cpp
@@ -2094,7 +2094,7 @@ void IndexLowering::handle(const LoadStoreOp* ldst) {
2};
} else if (ir_utils::isStMatrixOp(ldst)) {
NVF_ERROR(
ldst->out()->as<TensorView>()->getLogicalDomain().size() == 2,
ldst->out()->as<TensorView>()->getLogicalDomain().size() >= 2,
"We only support 2D inputs stmatrix");

NVF_ERROR(
104 changes: 62 additions & 42 deletions csrc/scheduler/hopper_multi_matmul.cpp
@@ -29,25 +29,70 @@

namespace nvfuser {

void HopperMultipleMatmulScheduler::transformLikeMmaOutput(TensorView* tv) {
void HopperMultipleMatmulScheduler::transformLikeMmaOutputWithK(
TensorView* tv) {
NVF_ERROR(tv->axis(-1)->isReduction(), "Inner axis should be Reduction.");
// The input is originally block tiled so that the inner dims are the CTA tile
// size
//
// We split this into warp tiles then instruction tiles
// Original: [..., M, N, K]
tv->split(-3, params_->tile_sizes.warp_tile.m);
tv->split(-3, getM(params_->mma_macro));
tv->split(-2, params_->tile_sizes.warp_tile.n);
tv->split(-2, getN(params_->mma_macro));
// K dimension is present for mma_result
// We don't need to split by warp_tile.k, since we always have
// cta_tile.k == warp_tile.k
tv->split(-1, getK(params_->mma_macro));
// After Split: [..., Mo, Mw, Mi, No, Nw, Ni, Kw, Ki]
tv->reorder({
{-8, -8}, // Mo
{-7, -6}, // Mw
{-6, -3}, // Mi
{-5, -7}, // No
{-4, -5}, // Nw
{-3, -2}, // Ni
{-2, -4}, // Kw
{-1, -1}, // Ki
});
// After Reorder: [..., Mo, No, Mw, Nw, Kw, Mi, Ni, Ki]
tv->merge(-8);
// After Merge: [..., Mo * No, Mw, Nw, Kw, Mi, Ni, Ki]
tv->axis(-7)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Kw, Mi, Ni, Ki]
}

void HopperMultipleMatmulScheduler::transformLikeMmaOutputWithoutK(
TensorView* tv) {
NVF_ERROR(
tv->domain()->loop().size() >= 4,
"transformLikeMmaOutput requires at least four iterDomains but ",
"transformLikeMmaOutputWithoutK requires at least four iterDomains but ",
tv->toString(),
" only has ",
tv->domain()->loop().size(),
".");
NVF_ERROR(
!tv->axis(-1)->isReduction(), "Inner axis should not be Reduction.");

// Original: [..., Mo, No, Mi, Ni]
// The input is originally block tiled so that the inner dims are the CTA tile
// size
// Original: [..., M, N]
// We split this into warp tiles then instruction tiles
tv->split(-2, params_->tile_sizes.warp_tile.m);
tv->split(-2, getM(params_->mma_macro));
tv->split(-1, params_->tile_sizes.warp_tile.n);
tv->split(-1, getN(params_->mma_macro));
// After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
tv->reorder({{-3, -2}});
// After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
tv->merge(-4);
// After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
tv->axis(-3)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
// After Split: [..., Mo, Mw, Mi, No, Nw, Ni]
tv->reorder({
{-3, -5},
{-2, -3},
});
// After Reorder: [..., Mo, No, Mw, Nw, Mi, Ni]
tv->merge(-6);
// After Merge: [..., Mo * No, Mw, Nw, Mi, Ni]
tv->axis(-5)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Mi, Ni]
}
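As a sanity check on the bookkeeping in the two functions above, here is a small paper model of split/reorder/merge that tracks only axis labels, ignoring extents. This is a sketch, not nvFuser's `TensorView` API; it reproduces the `transformLikeMmaOutputWithK` trace, and the without-K case is analogous:

```c++
#include <cstdio>
#include <map>
#include <string>
#include <vector>

// Toy model of TensorView::split/reorder/merge tracking axis names only.
using Domain = std::vector<std::string>;

int norm(const Domain& d, int pos) {
  return pos < 0 ? static_cast<int>(d.size()) + pos : pos;
}

// Split axis `pos` into outer/inner (extents are ignored in this model).
void split(Domain& d, int pos, const std::string& outer, const std::string& inner) {
  int i = norm(d, pos);
  d[i] = outer;
  d.insert(d.begin() + i + 1, inner);
}

// Move axes per old->new positions; unmentioned axes keep relative order.
void reorder(Domain& d, const std::map<int, int>& old2new) {
  Domain out(d.size());
  std::vector<bool> src_used(d.size(), false), dst_used(d.size(), false);
  for (auto [o, n] : old2new) {
    out[norm(d, n)] = d[norm(d, o)];
    src_used[norm(d, o)] = dst_used[norm(d, n)] = true;
  }
  for (int i = 0, j = 0; i < static_cast<int>(d.size()); ++i) {
    if (src_used[i]) continue;
    while (dst_used[j]) ++j;
    out[j++] = d[i];
  }
  d = out;
}

// Merge axis `pos` with the axis to its right.
void merge(Domain& d, int pos) {
  int i = norm(d, pos);
  d[i] += "*" + d[i + 1];
  d.erase(d.begin() + i + 1);
}

int main() {
  Domain d = {"M", "N", "K"};
  split(d, -3, "Mo", "Mw"); split(d, -3, "Mw", "Mi");
  split(d, -2, "No", "Nw"); split(d, -2, "Nw", "Ni");
  split(d, -1, "Kw", "Ki");
  // d == [Mo, Mw, Mi, No, Nw, Ni, Kw, Ki]
  reorder(d, {{-8, -8}, {-7, -6}, {-6, -3}, {-5, -7},
              {-4, -5}, {-3, -2}, {-2, -4}, {-1, -1}});
  // d == [Mo, No, Mw, Nw, Kw, Mi, Ni, Ki]
  merge(d, -8);
  for (const auto& s : d) std::printf("%s ", s.c_str());
  std::printf("\n");  // prints: Mo*No Mw Nw Kw Mi Ni Ki
  return 0;
}
```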

MatmulDimRole HopperMultipleMatmulScheduler::findMatmulDimRole(IterDomain* id) {
@@ -365,6 +410,7 @@ void HopperMultipleMatmulScheduler::scheduleOperands() {
const std::vector<TensorView*>& smem_operands,
MmaOperand operand_type) {
blockTileTensors(smem_operands);
parallelizeBlocks(smem_operands);
for (TensorView* tv : smem_operands) {
if (params_->promote_prologue_smem_reuse) {
tv->promoteReuse();
@@ -452,33 +498,7 @@ void HopperMultipleMatmulScheduler::scheduleMmaResults() {
splitk_sums_.push_back(splitk_sum);
}

// Original: [..., Mo, No, Mi, Ni, Ki]
mma_result->split(-3, getM(params_->mma_macro));
mma_result->split(-2, getN(params_->mma_macro));

// Split k dimension of warp tile only if it is larger than k dimension of
// mma macro. Inlining can be at incorrect position for circular buffering
// if a reduction iterDomain has iterDomain 1.
if (params_->tile_sizes.warp_tile.k > getK(params_->mma_macro)) {
mma_result->split(-1, getK(params_->mma_macro));
// After Split: [..., Mo, No, Mio, Mii, Nio, Nii, Kio, Kii]
mma_result->reorder({{-5, -4}});
// After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii, Kio, Kii]
mma_result->reorder({{-2, -4}});
// After Reorder: [..., Mo, No, Mio, Nio, Kio, Mii, Nii, Kii]
mma_result->merge(-6);
// After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
mma_result->axis(-5)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
} else {
// After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
mma_result->reorder({{-4, -3}});
// After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
mma_result->merge(-5);
// After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
mma_result->axis(-4)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
}
transformLikeMmaOutputWithK(mma_result);

auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
mma_result->getLoopDomain());
@@ -514,7 +534,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
// op.
blockTileTensors({d});
parallelizeBlocks({d});
transformLikeMmaOutput(d);
transformLikeMmaOutputWithoutK(d);

auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
d->getLoopDomain());
@@ -545,8 +565,8 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
// tile is a multiple of the macro size because stmatrix stores results from
// wgmma to shared memory. For maximum inlining and to reduce shared memory
usage, the tma tile is the warp tile size.
const int64_t tma_m = getM(params_->mma_macro);
const int64_t tma_n = getN(params_->mma_macro);
const int64_t tma_m = params_->tile_sizes.warp_tile.m;
const int64_t tma_n = params_->tile_sizes.warp_tile.n;

fusion_->manage("st_matrix_m_tile", stmatrix_tile_m);
fusion_->manage("st_matrix_n_tile", stmatrix_tile_n);
@@ -594,7 +614,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
blockTileTensors(tvs_to_schedule);
parallelizeBlocks(tvs_to_schedule);
for (auto tv : tvs_to_schedule) {
transformLikeMmaOutput(tv);
transformLikeMmaOutputWithoutK(tv);
}

// Should not propagate if the dc is a mma output as the mma output has
@@ -645,7 +665,7 @@ void HopperMultipleMatmulScheduler::scheduleSplitKSum() {
for (TensorView* splitk_sum : splitk_sums_) {
// Always use serial grid reduction for split-K sum
splitk_sum->definition()->as<ReductionOp>()->requestSerialGridReduction();
transformLikeMmaOutput(splitk_sum);
transformLikeMmaOutputWithoutK(splitk_sum);
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
splitk_sum->getLoopDomain());
splitk_sum->setLoopDomain(s.as<IterDomain*>());
7 changes: 6 additions & 1 deletion csrc/scheduler/hopper_multi_matmul.h
@@ -191,7 +191,12 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler {
// Schedule a block-tiled TensorView like mma output.
// Why? WGMMA has a unique output format. TensorViews after the mma-result in
// registers must respect this format for correctness.
void transformLikeMmaOutput(TensorView* tv);
// This version is meant to be used on the mma_result, which has a Reduction
// K axis.
void transformLikeMmaOutputWithK(TensorView* tv);

// This is like the above method, but tv should not have any K dimension
void transformLikeMmaOutputWithoutK(TensorView* tv);

private:
std::vector<ValGroup> canonical_dim_ordering_;
138 changes: 138 additions & 0 deletions tests/cpp/test_matmul.cpp
@@ -4519,4 +4519,142 @@ INSTANTIATE_TEST_SUITE_P(
return ss.str();
});

// This tests that we can use a small instruction tile with a medium size
// warpgroup tile and a large CTA tile.
TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) {
Fusion fusion;
FusionGuard fg(&fusion);

constexpr int64_t M = 2048, N = 2048, K = 8192;
const auto dtype = DataType::Half;

auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // K, M
auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype); // K, N
fusion.addInput(tv0);
fusion.addInput(tv1);

auto tv2 = fusedMultiplySum(tv0, tv1, {0});

// Reorder the accumulator as [M, N, K]
// [K, M, N] -> [M, N, K]
tv2->reorder({{-3, -1}});
tv2->commitLeafToLogical();

auto tv3 = castOp(DataType::Half, tv2);
fusion.addOutput(tv3);

auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
auto a_ref = at::randn({K, M, 1}, options);
auto b_ref = at::randn({K, 1, N}, options);
auto out_ref = at::matmul(a_ref.squeeze().t(), b_ref.squeeze()).to(at::kHalf);

MatMulTileOptions gemm_tile;
// Regardless of the instruction, this should result in 4 warp groups,
// i.e. 512 threads
gemm_tile.cta_tile = GemmTile(256, 256, 32);
gemm_tile.warp_tile = GemmTile(128, 128, 32);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 8};
mparams.mma_macro = MmaMacro::Hopper_64_64_16;
mparams.tile_sizes = gemm_tile;
mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor;
mparams.async_gmem_load_operands = true;
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = false;
mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
mparams.splitk_factor = 1;
// NOTE: disabling smem use for this test since we currently hit a bank
// conflict.
// TODO: enable smem epilogue once stmatrix is updated
mparams.use_smem_epilogue = false;
mparams.cluster_dims = {2, 1, 1};
mparams.promote_prologue_smem_reuse = false;

SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
->schedule(&fusion, &mparams);

std::vector<c10::IValue> inputs = {a_ref, b_ref};

KernelExecutor ke;
ke.compile(&fusion, inputs);
kir::Kernel* kernel = ke.compiledKernel()->kernel();
ASSERT_TRUE(kernel != nullptr);
EXPECT_TRUE(getBankConflictInfo(kernel).empty());
EXPECT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(kernel));

auto cg_outputs = ke.run(inputs);

// Check number of launched threads matches what we expect
EXPECT_EQ(ke.lastLaunchParams().bdimx(), 128);
EXPECT_EQ(ke.lastLaunchParams().bdimy(), 4)
<< " expected 4 warp groups (BIDy==4) but found BIDy=="
<< ke.lastLaunchParams().bdimy();

// Relax tolerance for larger sum due to large K
EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
}
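The launch-dimension expectations in this test follow from the tile sizes alone; a sketch of that arithmetic (plain C++, constants copied from the test above):

```c++
#include <cstdint>

int main() {
  // From the test: CTA tile 256x256x32, warp tile 128x128x32.
  constexpr int64_t cta_m = 256, cta_n = 256;
  constexpr int64_t warp_m = 128, warp_n = 128;
  // One warp group (128 threads) spans TIDx; the warp tiles within the
  // CTA tile are laid out along TIDy.
  constexpr int64_t bdimx = 128;
  constexpr int64_t bdimy = (cta_m / warp_m) * (cta_n / warp_n);
  static_assert(bdimx == 128, "one warp group wide");
  static_assert(bdimy == 4, "four warp groups per CTA");
  static_assert(bdimx * bdimy == 512, "512 threads per CTA");
  return 0;
}
```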

TEST_F(HopperMatmulTest, ScheduleWithTranslation) {
Fusion fusion;
FusionGuard fg(&fusion);

constexpr int64_t M = 2048, N = 2048, K = 8192;
const auto dtype = DataType::Half;

auto tv0 = makeContigConcreteTensor({-1, -1}, dtype); // M, K
auto tv1 = makeContigConcreteTensor({-1, -1}, dtype); // K, N
// Note: a transposed allocation domain for tv1 is possible but disabled here
// tv1->setAllocationDomain({tv1->axis(1), tv1->axis(0)}, true);
fusion.addInput(tv0);
fusion.addInput(tv1);

auto tv2 = matmul(tv0, tv1);

fusion.addOutput(tv2);

auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
auto a_ref = at::randn({M, K}, options);
// auto b_ref = at::randn({N, K}, options).t();
auto b_ref = at::randn({K, N}, options);
auto out_ref = at::matmul(a_ref, b_ref);

MatMulTileOptions gemm_tile;
gemm_tile.cta_tile = GemmTile(128, 256, 16);
gemm_tile.warp_tile = GemmTile(64, 64, 16);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 8};
mparams.mma_macro = MmaMacro::Hopper_64_64_16;
mparams.tile_sizes = gemm_tile;
mparams.cta_order = MatmulParams::TileRasterizationOrder::RowMajor;
mparams.async_gmem_load_operands = true;
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = false;
mparams.circular_buffer_options.smem_circular_buffer_stage = 3;
mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
mparams.splitk_factor = 1;
mparams.use_smem_epilogue = true;
mparams.cluster_dims = {1, 1, 1};
mparams.promote_prologue_smem_reuse = true;

SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
->schedule(&fusion, &mparams);

std::vector<c10::IValue> inputs = {a_ref, b_ref};

KernelExecutor ke;
ke.compile(&fusion, inputs);
kir::Kernel* kernel = ke.compiledKernel()->kernel();
ASSERT_TRUE(kernel != nullptr);
EXPECT_TRUE(getBankConflictInfo(kernel).empty());
EXPECT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(kernel));

auto cg_outputs = ke.run(inputs);

// Relax tolerance for larger sum due to large K
EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
}

} // namespace nvfuser
7 changes: 6 additions & 1 deletion tests/cpp/test_translate_mma.cpp
@@ -315,14 +315,19 @@ using MatmulNodeTranslationTest =
// Test that a simple matmul op fusion is picked up by the appropriate scheduler
// and the translation to MmaOp is performed properly.
TEST_P(MatmulNodeTranslationTest, AutomaticSchedulerMatmulNode) {
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 10, 0);
const int64_t A_dim = std::get<0>(GetParam());
const int64_t B_dim = std::get<1>(GetParam());
const bool enable_fusion = std::get<2>(GetParam());
const bool transpose_a_alloc = std::get<3>(GetParam());
const bool expect_segmented = std::get<4>(GetParam());
const SchedulerType expected_heuristic = std::get<5>(GetParam());

if (A_dim == 3 && B_dim == 2) {
// TODO: Fix the failure at checkConcreteStaticDim on Hopper in this case
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);
}

// CombineMulSumAsMmaTest disabled MatmulExprEval, but we need it enabled
DisableOptionsGuard dog;
DisableOptionsGuard::getCurOptions().unset(DisableOption::MatmulExprEval);
