Commit

Merge remote-tracking branch 'origin/hopper_warptile_split' into jh/persistent_kernel_impl
jacobhinkle committed Jan 31, 2025
2 parents 95cf199 + 9b5e73c commit ffa276e
Showing 5 changed files with 211 additions and 26 deletions.
2 changes: 1 addition & 1 deletion csrc/device_lower/pass/index.cpp
@@ -2094,7 +2094,7 @@ void IndexLowering::handle(const LoadStoreOp* ldst) {
2};
} else if (ir_utils::isStMatrixOp(ldst)) {
NVF_ERROR(
ldst->out()->as<TensorView>()->getLogicalDomain().size() == 2,
ldst->out()->as<TensorView>()->getLogicalDomain().size() >= 2,
"We only support 2D inputs stmatrix");

NVF_ERROR(
83 changes: 60 additions & 23 deletions csrc/scheduler/hopper_multi_matmul.cpp
@@ -29,25 +29,61 @@

namespace nvfuser {

void HopperMultipleMatmulScheduler::transformLikeMmaOutput(
TensorView* tv,
bool is_mma_result) {
// TODO Add constraints
void HopperMultipleMatmulScheduler::transformLikeMmaOutputWithK(
TensorView* tv) {
// The input is originally block tiled so that the inner dims are the CTA tile
// size
//
// We split this into warp tiles then instruction tiles
// Original: [..., M, N, K]
tv->split(-3, params_->tile_sizes.warp_tile.m);
tv->split(-3, getM(params_->mma_macro));
tv->split(-2, params_->tile_sizes.warp_tile.n);
tv->split(-2, getN(params_->mma_macro));
// K dimension is present for mma_result
// We don't need to split by warp_tile.k, since we always have
// cta_tile.k==warp_tile.k
tv->split(-1, getK(params_->mma_macro));
// After Split: [..., Mo, Mw, Mi, No, Nw, Ni, Kw, Ki]
tv->reorder({
{-8, -8}, // Mo
{-7, -6}, // Mw
{-6, -3}, // Mi
{-5, -7}, // No
{-4, -5}, // Nw
{-3, -2}, // Ni
{-2, -4}, // Kw
{-1, -1}, // Ki
});
// After Reorder: [..., Mo, No, Mw, Nw, Kw, Mi, Ni, Ki]
tv->merge(-8);
// After Merge: [..., Mo * No, Mw, Nw, Kw, Mi, Ni, Ki]
tv->axis(-7)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Kw, Mi, Ni, Ki]
}
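// Illustrative standalone sketch (not from this file): the extent arithmetic
// of the splits above, assuming the tile sizes used by the new
// MultipleInstructionsPerWarpTile test added in this commit
// (cta_tile 256x256x32, warp_tile 128x128x32, mma_macro Hopper_64_64_16).
#include <cstdint>
constexpr int64_t cta_m = 256, cta_n = 256, cta_k = 32;  // CTA tile
constexpr int64_t warp_m = 128, warp_n = 128;            // warp tile (warp_tile.k == cta_tile.k)
constexpr int64_t inst_m = 64, inst_n = 64, inst_k = 16; // instruction tile
static_assert(cta_m / warp_m == 2 && cta_n / warp_n == 2, "Mo and No");
static_assert(warp_m / inst_m == 2 && warp_n / inst_n == 2, "Mw and Nw");
static_assert(cta_k / inst_k == 2, "Kw");
// Resulting loop domain: [Mo * No = 4 (TIDy), Mw = 2, Nw = 2, Kw = 2, 64, 64, 16],
// i.e. four warp groups per CTA, which matches the bdimy == 4 expectation in
// that test.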

auto apply_k_dim_offset = [is_mma_result](int64_t idx) constexpr {
return (is_mma_result) ? idx - 1 : idx;
};
void HopperMultipleMatmulScheduler::transformLikeMmaOutputWithoutK(
TensorView* tv) {
// TODO Add constraints

// Original: [..., Mo, No, Mi, Ni]
tv->split(apply_k_dim_offset(-2), getM(params_->mma_macro));
tv->split(apply_k_dim_offset(-1), getN(params_->mma_macro));
// After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
tv->reorder({{apply_k_dim_offset(-3), apply_k_dim_offset(-2)}});
// After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
tv->merge(apply_k_dim_offset(-4));
// After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
tv->axis(apply_k_dim_offset(-3))->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
// The input is originally block tiled so that the inner dims are the CTA tile
// size
// Original: [..., M, N]
// We split this into warp tiles then instruction tiles
tv->split(-2, params_->tile_sizes.warp_tile.m);
tv->split(-2, getM(params_->mma_macro));
tv->split(-1, params_->tile_sizes.warp_tile.n);
tv->split(-1, getN(params_->mma_macro));
// After Split: [..., Mo, Mw, Mi, No, Nw, Ni]
tv->reorder({
{-3, -5},
{-2, -3},
});
// After Reorder: [..., Mo, No, Mw, Nw, Mi, Ni]
tv->merge(-6);
// After Merge: [..., Mo * No, Mw, Nw, Mi, Ni]
tv->axis(-5)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Mi, Ni]
}
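// Illustrative standalone sketch (not from this file) of the partial
// "old position -> new position" reorder used above. It assumes, consistent
// with the before/after comments, that axes not listed in the map keep their
// relative order in the remaining slots; negative positions count from the end.
#include <cstdio>
#include <map>
#include <string>
#include <vector>

std::vector<std::string> reorderLike(
    const std::vector<std::string>& axes,
    const std::map<int, int>& old2new) {
  const int n = static_cast<int>(axes.size());
  auto wrap = [n](int i) { return i < 0 ? i + n : i; };
  std::vector<std::string> result(n);
  std::vector<bool> new_taken(n, false), old_taken(n, false);
  // Place the explicitly mapped axes first.
  for (const auto& [old_pos, new_pos] : old2new) {
    result[wrap(new_pos)] = axes[wrap(old_pos)];
    new_taken[wrap(new_pos)] = true;
    old_taken[wrap(old_pos)] = true;
  }
  // Fill the remaining slots with unmapped axes in their original order.
  int slot = 0;
  for (int i = 0; i < n; ++i) {
    if (old_taken[i]) {
      continue;
    }
    while (new_taken[slot]) {
      ++slot;
    }
    result[slot++] = axes[i];
  }
  return result;
}

int main() {
  // [Mo, Mw, Mi, No, Nw, Ni] with {{-3, -5}, {-2, -3}} -> [Mo, No, Mw, Nw, Mi, Ni]
  for (const auto& axis :
       reorderLike({"Mo", "Mw", "Mi", "No", "Nw", "Ni"}, {{-3, -5}, {-2, -3}})) {
    std::printf("%s ", axis.c_str());
  }
  std::printf("\n");
  return 0;
}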

MatmulDimRole HopperMultipleMatmulScheduler::findMatmulDimRole(IterDomain* id) {
@@ -365,6 +401,7 @@ void HopperMultipleMatmulScheduler::scheduleOperands() {
const std::vector<TensorView*>& smem_operands,
MmaOperand operand_type) {
blockTileTensors(smem_operands);
parallelizeBlocks(smem_operands);
for (TensorView* tv : smem_operands) {
if (params_->promote_prologue_smem_reuse) {
tv->promoteReuse();
@@ -452,7 +489,7 @@ void HopperMultipleMatmulScheduler::scheduleMmaResults() {
splitk_sums_.push_back(splitk_sum);
}

transformLikeMmaOutput(mma_result, /*is_mma_result=*/true);
transformLikeMmaOutputWithK(mma_result);
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
mma_result->getLoopDomain());
mma_result->setAllocationDomain(s.as<IterDomain*>(), true);
@@ -487,7 +524,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
// op.
blockTileTensors({d});
parallelizeBlocks({d});
transformLikeMmaOutput(d, /*is_mma_result=*/false);
transformLikeMmaOutputWithoutK(d);

auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
d->getLoopDomain());
@@ -518,8 +555,8 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
// tile is a multiple of the macro size because stmatrix stores results from
// wgmma to shared memory. For maximum inlining and to reduce shared memory
// usage, the tma tile is mma_macro size.
const int64_t tma_m = getM(params_->mma_macro);
const int64_t tma_n = getN(params_->mma_macro);
const int64_t tma_m = params_->tile_sizes.warp_tile.m;
const int64_t tma_n = params_->tile_sizes.warp_tile.n;
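// Illustrative sketch (not from this file), assuming warp_tile = 128x128 and
// MmaMacro::Hopper_64_64_16 as in the new test below: the TMA store tile is
// now the warp tile rather than the mma_macro tile, and it remains a multiple
// of the macro size, which the stmatrix path requires.
constexpr int64_t example_tma_m = 128, example_tma_n = 128;    // warp_tile.m/.n
constexpr int64_t example_macro_m = 64, example_macro_n = 64;  // getM/getN(macro)
static_assert(example_tma_m % example_macro_m == 0);
static_assert(example_tma_n % example_macro_n == 0);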

fusion_->manage("st_matrix_m_tile", stmatrix_tile_m);
fusion_->manage("st_matrix_n_tile", stmatrix_tile_n);
@@ -567,7 +604,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
blockTileTensors(tvs_to_schedule);
parallelizeBlocks(tvs_to_schedule);
for (auto tv : tvs_to_schedule) {
transformLikeMmaOutput(tv, /*is_mma_result=*/false);
transformLikeMmaOutputWithoutK(tv);
}

// Should not propagate if the dc is a mma output as the mma output has
@@ -618,7 +655,7 @@ void HopperMultipleMatmulScheduler::scheduleSplitKSum() {
for (TensorView* splitk_sum : splitk_sums_) {
// Always use serial grid reduction for split-K sum
splitk_sum->definition()->as<ReductionOp>()->requestSerialGridReduction();
transformLikeMmaOutput(splitk_sum, /*is_mma_result=*/false);
transformLikeMmaOutputWithoutK(splitk_sum);
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
splitk_sum->getLoopDomain());
splitk_sum->setLoopDomain(s.as<IterDomain*>());
7 changes: 6 additions & 1 deletion csrc/scheduler/hopper_multi_matmul.h
@@ -191,7 +191,12 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler {
// Schedule a block-tiled TensorView like mma output.
// Why? WGMMA has a unique output format. TensorViews after the mma-result in
// registers must respect this format for correctness.
void transformLikeMmaOutput(TensorView* tv, bool is_mma_result);
// This version is meant to be used on the mma_result, which has a Reduction
// K axis.
void transformLikeMmaOutputWithK(TensorView* tv);

// This is like the above method, but tv should not have any K dimension
void transformLikeMmaOutputWithoutK(TensorView* tv);
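// For reference, the call sites updated in this commit (hopper_multi_matmul.cpp):
//   transformLikeMmaOutputWithK(mma_result);     // scheduleMmaResults()
//   transformLikeMmaOutputWithoutK(d);           // scheduleEpilogue()
//   transformLikeMmaOutputWithoutK(splitk_sum);  // scheduleSplitKSum()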

private:
std::vector<ValGroup> canonical_dim_ordering_;
138 changes: 138 additions & 0 deletions tests/cpp/test_matmul.cpp
@@ -4518,4 +4518,142 @@ INSTANTIATE_TEST_SUITE_P(
return ss.str();
});

// This tests that we can use a small instruction tile with a medium-sized
// warp-group tile and a large CTA tile.
TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) {
Fusion fusion;
FusionGuard fg(&fusion);

constexpr int64_t M = 2048, N = 2048, K = 8192;
const auto dtype = DataType::Half;

auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // K, M
auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype); // K, N
fusion.addInput(tv0);
fusion.addInput(tv1);

auto tv2 = fusedMultiplySum(tv0, tv1, {0});

// Reorder the accumulator as [M, N, K]
// [K, M, N] -> [M, N, K]
tv2->reorder({{-3, -1}});
tv2->commitLeafToLogical();

auto tv3 = castOp(DataType::Half, tv2);
fusion.addOutput(tv3);

auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
auto a_ref = at::randn({K, M, 1}, options);
auto b_ref = at::randn({K, 1, N}, options);
auto out_ref = at::matmul(a_ref.squeeze().t(), b_ref.squeeze()).to(at::kHalf);

MatMulTileOptions gemm_tile;
// Regardless of the instruction, this tile configuration should result in 4
// warp groups, i.e. 512 threads
gemm_tile.cta_tile = GemmTile(256, 256, 32);
gemm_tile.warp_tile = GemmTile(128, 128, 32);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 8};
mparams.mma_macro = MmaMacro::Hopper_64_64_16;
mparams.tile_sizes = gemm_tile;
mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor;
mparams.async_gmem_load_operands = true;
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = false;
mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
mparams.splitk_factor = 1;
// NOTE: disabling smem use for this test since we currently hit a bank
// conflict.
// TODO: enable smem epilogue once stmatrix is updated
mparams.use_smem_epilogue = false;
mparams.cluster_dims = {2, 1, 1};
mparams.promote_prologue_smem_reuse = false;

SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
->schedule(&fusion, &mparams);

std::vector<c10::IValue> inputs = {a_ref, b_ref};

KernelExecutor ke;
ke.compile(&fusion, inputs);
kir::Kernel* kernel = ke.compiledKernel()->kernel();
ASSERT_TRUE(kernel != nullptr);
EXPECT_TRUE(getBankConflictInfo(kernel).empty());
EXPECT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(kernel));

auto cg_outputs = ke.run(inputs);

// Check number of launched threads matches what we expect
EXPECT_EQ(ke.lastLaunchParams().bdimx(), 128);
EXPECT_EQ(ke.lastLaunchParams().bdimy(), 4)
<< " expected 4 warp groups (bdimy==4) but found bdimy=="
<< ke.lastLaunchParams().bdimy();

// Relax tolerance for larger sum due to large K
EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
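// For reference: with K = 8192 the relaxed tolerance 1e-6 * K evaluates to
// roughly 8.2e-3 (absolute and relative), which absorbs the accumulation
// error over the long K reduction.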
}

TEST_F(HopperMatmulTest, ScheduleWithTranslation) {
Fusion fusion;
FusionGuard fg(&fusion);

constexpr int64_t M = 2048, N = 2048, K = 8192;
const auto dtype = DataType::Half;

auto tv0 = makeContigConcreteTensor({-1, -1}, dtype); // M, K
auto tv1 = makeContigConcreteTensor({-1, -1}, dtype); // K, N
// Note tv1 has allocation domain
// tv1->setAllocationDomain({tv1->axis(1), tv1->axis(0)}, true);
fusion.addInput(tv0);
fusion.addInput(tv1);

auto tv2 = matmul(tv0, tv1);

fusion.addOutput(tv2);

auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
auto a_ref = at::randn({M, K}, options);
// auto b_ref = at::randn({N, K}, options).t();
auto b_ref = at::randn({K, N}, options);
auto out_ref = at::matmul(a_ref, b_ref);

MatMulTileOptions gemm_tile;
gemm_tile.cta_tile = GemmTile(128, 256, 16);
gemm_tile.warp_tile = GemmTile(64, 64, 16);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 8};
mparams.mma_macro = MmaMacro::Hopper_64_64_16;
mparams.tile_sizes = gemm_tile;
mparams.cta_order = MatmulParams::TileRasterizationOrder::RowMajor;
mparams.async_gmem_load_operands = true;
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = false;
mparams.circular_buffer_options.smem_circular_buffer_stage = 3;
mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
mparams.splitk_factor = 1;
mparams.use_smem_epilogue = true;
mparams.cluster_dims = {1, 1, 1};
mparams.promote_prologue_smem_reuse = true;

SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
->schedule(&fusion, &mparams);

std::vector<c10::IValue> inputs = {a_ref, b_ref};

KernelExecutor ke;
ke.compile(&fusion, inputs);
kir::Kernel* kernel = ke.compiledKernel()->kernel();
ASSERT_TRUE(kernel != nullptr);
EXPECT_TRUE(getBankConflictInfo(kernel).empty());
EXPECT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(kernel));

auto cg_outputs = ke.run(inputs);

// Relax tolerance for larger sum due to large K
EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
}

} // namespace nvfuser
7 changes: 6 additions & 1 deletion tests/cpp/test_translate_mma.cpp
@@ -315,14 +315,19 @@ using MatmulNodeTranslationTest =
// Test that a simple matmul op fusion is picked up by the appropriate scheduler
// and the translation to MmaOp is performed properly.
TEST_P(MatmulNodeTranslationTest, AutomaticSchedulerMatmulNode) {
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 10, 0);
const int64_t A_dim = std::get<0>(GetParam());
const int64_t B_dim = std::get<1>(GetParam());
const bool enable_fusion = std::get<2>(GetParam());
const bool transpose_a_alloc = std::get<3>(GetParam());
const bool expect_segmented = std::get<4>(GetParam());
const SchedulerType expected_heuristic = std::get<5>(GetParam());

if (A_dim == 3 && B_dim == 2) {
// TODO: Fix the failure at checkConcreteStaticDim on Hopper in this case
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);
}

// CombineMulSumAsMmaTest disabled MatmulExprEval, but we need it enabled
DisableOptionsGuard dog;
DisableOptionsGuard::getCurOptions().unset(DisableOption::MatmulExprEval);
