Split Hopper MMA by warp-tile before instruction tile #3642

Merged: 21 commits, Feb 6, 2025

Commits (21)
851669a
Split Hopper MMA by warp-tile before instruction tile
jacobhinkle Dec 24, 2024
8b42cd6
Use 4 warpgroups, disable smem epilogue
jacobhinkle Dec 31, 2024
7c6d417
Merge branch 'main' into hopper_warptile_split
jacobhinkle Dec 31, 2024
521d5cc
Use warp_tile for tma_m and tma_n
jacobhinkle Dec 31, 2024
dce16ad
Two warp tiles per CTA in each dim, increase instr to 64_64_16
jacobhinkle Jan 2, 2025
f5e084c
Also split by K
jacobhinkle Jan 2, 2025
be705bf
Add ScheduleWithTranslation test (failing)
jacobhinkle Jan 7, 2025
9de3202
Merge remote-tracking branch 'origin/main' into hopper_warptile_split
jacobhinkle Jan 8, 2025
41e2b94
Merge remote-tracking branch 'origin/main' into hopper_warptile_split
jacobhinkle Jan 17, 2025
e010ead
Merge remote-tracking branch 'origin/main' into hopper_warptile_split
jacobhinkle Jan 28, 2025
5246fb3
Update to fix compilation
jacobhinkle Jan 28, 2025
1dccf22
Don't do K split. Fix TMA offset
jacobhinkle Jan 28, 2025
496d8d7
Merge remote-tracking branch 'origin/main' into hopper_warptile_split
jacobhinkle Jan 29, 2025
7868900
Unguard most matmul node translation tests on Hopper
jacobhinkle Jan 29, 2025
9b5e73c
lintrunner
jacobhinkle Jan 29, 2025
a52274c
Merge remote-tracking branch 'origin/main' into hopper_warptile_split
jacobhinkle Feb 3, 2025
db3b93a
Fix busted merge
jacobhinkle Feb 3, 2025
628d849
Apply suggestions from code review
jacobhinkle Feb 5, 2025
73739ed
Merge remote-tracking branch 'origin/main' into hopper_warptile_split
jacobhinkle Feb 5, 2025
9cbbc2e
Add checks for reduction and non-reduction dims
jacobhinkle Feb 5, 2025
1d4697e
lintrunner
jacobhinkle Feb 6, 2025
2 changes: 1 addition & 1 deletion csrc/device_lower/pass/index.cpp
@@ -2094,7 +2094,7 @@ void IndexLowering::handle(const LoadStoreOp* ldst) {
2};
} else if (ir_utils::isStMatrixOp(ldst)) {
NVF_ERROR(
-ldst->out()->as<TensorView>()->getLogicalDomain().size() == 2,
ldst->out()->as<TensorView>()->getLogicalDomain().size() >= 2,
"We only support 2D inputs stmatrix");

NVF_ERROR(
109 changes: 60 additions & 49 deletions csrc/scheduler/hopper_multi_matmul.cpp
@@ -29,25 +29,61 @@

namespace nvfuser {

-void HopperMultipleMatmulScheduler::transformLikeMmaOutput(TensorView* tv) {
-NVF_ERROR(
-tv->domain()->loop().size() >= 4,
-"transformLikeMmaOutput requires at least four iterDomains but ",
-tv->toString(),
-" only has ",
-tv->domain()->loop().size(),
-".");
-
-// Original: [..., Mo, No, Mi, Ni]
void HopperMultipleMatmulScheduler::transformLikeMmaOutputWithK(
TensorView* tv) {
// The input is originally block tiled so that the inner dims are the CTA tile
Review comment (Collaborator): Do we have any conditions to check here?
Reply (jacobhinkle, author): Good point. I guess we should check that the inner dim is reduction at least.
// size
//
// We split this into warp tiles then instruction tiles
// Original: [..., M, N, K]
tv->split(-3, params_->tile_sizes.warp_tile.m);
tv->split(-3, getM(params_->mma_macro));
tv->split(-2, params_->tile_sizes.warp_tile.n);
tv->split(-2, getN(params_->mma_macro));
// K dimension is present for mma_result
// We don't need to split by warp_tile.k, since we always have
// cta_tile.k==warp_tile.k
tv->split(-1, getK(params_->mma_macro));
// After Split: [..., Mo, Mw, Mi, No, Nw, Ni, Kw, Ki]
tv->reorder({
{-8, -8}, // Mo
{-7, -6}, // Mw
{-6, -3}, // Mi
{-5, -7}, // No
{-4, -5}, // Nw
{-3, -2}, // Ni
{-2, -4}, // Kw
{-1, -1}, // Ki
});
// After Reorder: [..., Mo, No, Mw, Nw, Kw, Mi, Ni, Ki]
tv->merge(-8);
// After Merge: [..., Mo * No, Mw, Nw, Kw, Mi, Ni, Ki]
tv->axis(-7)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Kw, Mi, Ni, Ki]
}
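As a concrete illustration of the precondition discussed in the review thread above, a guard along the following lines could be added at the top of transformLikeMmaOutputWithK. This is an editor's sketch rather than code from this PR, and it assumes the usual nvfuser TensorView/IterDomain API (domain()->loop(), axis(), isReduction(), NVF_ERROR):

// Editor's sketch (not part of this diff): the tensor passed here is expected
// to end in [..., M, N, K] with K a reduction axis.
NVF_ERROR(
    tv->domain()->loop().size() >= 3,
    "transformLikeMmaOutputWithK expects at least [..., M, N, K] but got ",
    tv->toString());
NVF_ERROR(
    tv->axis(-1)->isReduction(),
    "Expected the innermost loop iterDomain to be the K (reduction) axis");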

void HopperMultipleMatmulScheduler::transformLikeMmaOutputWithoutK(
TensorView* tv) {
// TODO Add constraints

// The input is originally block tiled so that the inner dims are the CTA tile
// size
// Original: [..., M, N]
// We split this into warp tiles then instruction tiles
tv->split(-2, params_->tile_sizes.warp_tile.m);
tv->split(-2, getM(params_->mma_macro));
tv->split(-1, params_->tile_sizes.warp_tile.n);
tv->split(-1, getN(params_->mma_macro));
-// After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
-tv->reorder({{-3, -2}});
-// After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
-tv->merge(-4);
-// After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
-tv->axis(-3)->parallelize(ParallelType::TIDy);
-// After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
// After Split: [..., Mo, Mw, Mi, No, Nw, Ni]
tv->reorder({
{-3, -5},
{-2, -3},
});
// After Reorder: [..., Mo, No, Mw, Nw, Mi, Ni]
tv->merge(-6);
// After Merge: [..., Mo * No, Mw, Nw, Mi, Ni]
tv->axis(-5)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Mi, Ni]
}
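To make the split arithmetic above concrete, the following small standalone program (an editor's illustration, not part of the PR) evaluates the loop-domain extents that transformLikeMmaOutputWithK produces for one assumed configuration: cta_tile = (256, 256, 32), warp_tile = (128, 128, 32), and the Hopper_64_64_16 macro, i.e. the sizes used by the new test below. transformLikeMmaOutputWithoutK yields the same structure without the Kw and Ki axes.

#include <cstdio>

int main() {
  // Assumed tile sizes; the CTA tile is the starting [M, N, K] per block.
  const int cta_m = 256, cta_n = 256, cta_k = 32;
  const int warp_m = 128, warp_n = 128;            // warp_tile.k == cta_tile.k
  const int instr_m = 64, instr_n = 64, instr_k = 16;

  // Split M and N by the warp tile and then the instruction tile; split K by
  // the instruction tile only.
  const int Mo = cta_m / warp_m, Mw = warp_m / instr_m, Mi = instr_m;
  const int No = cta_n / warp_n, Nw = warp_n / instr_n, Ni = instr_n;
  const int Kw = cta_k / instr_k, Ki = instr_k;

  // Final loop domain: [Mo * No (TIDy), Mw, Nw, Kw, Mi, Ni, Ki]
  std::printf("[%d (TIDy), %d, %d, %d, %d, %d, %d]\n",
              Mo * No, Mw, Nw, Kw, Mi, Ni, Ki);
  // Prints [4 (TIDy), 2, 2, 2, 64, 64, 16]: four warp groups per CTA, each
  // covering a 128x128 warp tile with a 2x2 grid of 64x64 instruction tiles
  // and stepping through K in 16-deep instruction chunks.
  return 0;
}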

MatmulDimRole HopperMultipleMatmulScheduler::findMatmulDimRole(IterDomain* id) {
@@ -365,6 +401,7 @@ void HopperMultipleMatmulScheduler::scheduleOperands() {
const std::vector<TensorView*>& smem_operands,
MmaOperand operand_type) {
blockTileTensors(smem_operands);
parallelizeBlocks(smem_operands);
Review comment (jacobhinkle, author): We might as well also parallelize these. Note that we could just call this from blockTileTensors since we are always parallelizing outer dims for every tensor.
for (TensorView* tv : smem_operands) {
if (params_->promote_prologue_smem_reuse) {
tv->promoteReuse();
@@ -452,33 +489,7 @@ void HopperMultipleMatmulScheduler::scheduleMmaResults() {
splitk_sums_.push_back(splitk_sum);
}

-// Original: [..., Mo, No, Mi, Ni, Ki]
-mma_result->split(-3, getM(params_->mma_macro));
-mma_result->split(-2, getN(params_->mma_macro));
-
-// Split k dimension of warp tile only if it is larger than k dimension of
-// mma macro. Inlining can be at incorrect position for circular buffering
-// if a reduction iterDomain has iterDomain 1.
-if (params_->tile_sizes.warp_tile.k > getK(params_->mma_macro)) {
-mma_result->split(-1, getK(params_->mma_macro));
-// After Split: [..., Mo, No, Mio, Mii, Nio, Nii, Kio, Kii]
-mma_result->reorder({{-5, -4}});
-// After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii, Kio, Kii]
-mma_result->reorder({{-2, -4}});
-// After Reorder: [..., Mo, No, Mio, Nio, Kio, Mii, Nii, Kii]
-mma_result->merge(-6);
-// After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
-mma_result->axis(-5)->parallelize(ParallelType::TIDy);
-// After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
-} else {
-// After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
-mma_result->reorder({{-4, -3}});
-// After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
-mma_result->merge(-5);
-// After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
-mma_result->axis(-4)->parallelize(ParallelType::TIDy);
-// After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
-}
transformLikeMmaOutputWithK(mma_result);

auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
mma_result->getLoopDomain());
@@ -514,7 +525,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
// op.
blockTileTensors({d});
parallelizeBlocks({d});
-transformLikeMmaOutput(d);
transformLikeMmaOutputWithoutK(d);

auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
d->getLoopDomain());
@@ -545,8 +556,8 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
// tile is a multiple of the macro size because stmatrix stores results from
// wgmma to shared memory. For maximum inlining and to reduce shared memory
// usage, the tma tile is mma_macro size.
-const int64_t tma_m = getM(params_->mma_macro);
-const int64_t tma_n = getN(params_->mma_macro);
const int64_t tma_m = params_->tile_sizes.warp_tile.m;
const int64_t tma_n = params_->tile_sizes.warp_tile.n;

fusion_->manage("st_matrix_m_tile", stmatrix_tile_m);
fusion_->manage("st_matrix_n_tile", stmatrix_tile_n);
@@ -594,7 +605,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
blockTileTensors(tvs_to_schedule);
parallelizeBlocks(tvs_to_schedule);
for (auto tv : tvs_to_schedule) {
-transformLikeMmaOutput(tv);
transformLikeMmaOutputWithoutK(tv);
}

// Should not propagate if the dc is a mma output as the mma output has
@@ -645,7 +656,7 @@ void HopperMultipleMatmulScheduler::scheduleSplitKSum() {
for (TensorView* splitk_sum : splitk_sums_) {
// Always use serial grid reduction for split-K sum
splitk_sum->definition()->as<ReductionOp>()->requestSerialGridReduction();
-transformLikeMmaOutput(splitk_sum);
transformLikeMmaOutputWithoutK(splitk_sum);
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
splitk_sum->getLoopDomain());
splitk_sum->setLoopDomain(s.as<IterDomain*>());
7 changes: 6 additions & 1 deletion csrc/scheduler/hopper_multi_matmul.h
@@ -191,7 +191,12 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler {
// Schedule a block-tiled TensorView like mma output.
// Why? WGMMA has a unique output format. TensorViews after the mma-result in
// registers must respect this format for correctness.
-void transformLikeMmaOutput(TensorView* tv);
// This version is meant to be used on the mma_result, which has a Reduction
// K axis.
void transformLikeMmaOutputWithK(TensorView* tv);

// This is like the above method, but tv should not have any K dimension
void transformLikeMmaOutputWithoutK(TensorView* tv);

private:
std::vector<ValGroup> canonical_dim_ordering_;
138 changes: 138 additions & 0 deletions tests/cpp/test_matmul.cpp
@@ -4519,4 +4519,142 @@ INSTANTIATE_TEST_SUITE_P(
return ss.str();
});

// This tests that we can use a small instruction tile with a medium-size
// warp-group tile and a large CTA tile.
TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) {
Fusion fusion;
FusionGuard fg(&fusion);

constexpr int64_t M = 2048, N = 2048, K = 8192;
const auto dtype = DataType::Half;

auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // K, M
auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype); // K, N
fusion.addInput(tv0);
fusion.addInput(tv1);

auto tv2 = fusedMultiplySum(tv0, tv1, {0});

// Reorder the accumulator as [M, N, K]
// [K, M, N] -> [M, N, K]
tv2->reorder({{-3, -1}});
tv2->commitLeafToLogical();

auto tv3 = castOp(DataType::Half, tv2);
fusion.addOutput(tv3);

auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
auto a_ref = at::randn({K, M, 1}, options);
auto b_ref = at::randn({K, 1, N}, options);
auto out_ref = at::matmul(a_ref.squeeze().t(), b_ref.squeeze()).to(at::kHalf);

MatMulTileOptions gemm_tile;
// Regardless of the instruction, this should result in 4 warp groups, i.e. 512
// compute threads
gemm_tile.cta_tile = GemmTile(256, 256, 32);
gemm_tile.warp_tile = GemmTile(128, 128, 32);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 8};
mparams.mma_macro = MmaMacro::Hopper_64_64_16;
mparams.tile_sizes = gemm_tile;
mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor;
mparams.async_gmem_load_operands = true;
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = false;
mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
mparams.splitk_factor = 1;
// NOTE: disabling smem use for this test since we currently hit a bank
// conflict.
// TODO: enable smem epilogue once stmatrix is updated
mparams.use_smem_epilogue = false;
mparams.cluster_dims = {2, 1, 1};
mparams.promote_prologue_smem_reuse = false;

SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
->schedule(&fusion, &mparams);

std::vector<c10::IValue> inputs = {a_ref, b_ref};

KernelExecutor ke;
ke.compile(&fusion, inputs);
kir::Kernel* kernel = ke.compiledKernel()->kernel();
ASSERT_TRUE(kernel != nullptr);
EXPECT_TRUE(getBankConflictInfo(kernel).empty());
EXPECT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(kernel));

auto cg_outputs = ke.run(inputs);

// Check number of launched threads matches what we expect
EXPECT_EQ(ke.lastLaunchParams().bdimx(), 128);
EXPECT_EQ(ke.lastLaunchParams().bdimy(), 4)
<< " expected 4 warp groups (BIDy==4) but found BIDy=="
<< ke.lastLaunchParams().bdimy();

// Relax tolerance for larger sum due to large K
EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
}

TEST_F(HopperMatmulTest, ScheduleWithTranslation) {
Review comment (jacobhinkle, author): This test is pretty much identical to the previous one, but it uses a MatmulOp instead of fusedMultiplySum. This is currently failing (passes on main) with:

C++ exception with description " INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/device_lower/pass/circular_buffer.cpp":160, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. No IfThenElse should exist yet:
IF ElectSync:
  MBarrierArriveExpectTx(T9_s[i408] view( T9 ), 4096)
  FOR i372 in iB28{16}:
    FOR i375 in iB34{2}:
      FOR i373 in iB31{4}:
        FOR i376 in iB35{2}:
          FOR i374 in iB33{8}:
            T3_s___half[iblockIdx.x24{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 128) )}, bS22{1}, iS20{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[1] ), 16) )}, bS23{256}, iS26{1}, iB28{16}, iB34{2}, iB31{4}, iB35{2}, iB33{8}] ca_pos( 5 )
               = CpAsyncBulkTensorTile( T0_g___half[iS170{( (( (( getMetaData(T0) )).logical_size ))[0] )}, iS171{( (( (( getMetaData(T0) )).logical_size ))[1] )}] )

Fusion fusion;
FusionGuard fg(&fusion);

constexpr int64_t M = 2048, N = 2048, K = 8192;
const auto dtype = DataType::Half;

auto tv0 = makeContigConcreteTensor({-1, -1}, dtype); // M, K
auto tv1 = makeContigConcreteTensor({-1, -1}, dtype); // K, N
// Note tv1 has allocation domain
// tv1->setAllocationDomain({tv1->axis(1), tv1->axis(0)}, true);
fusion.addInput(tv0);
fusion.addInput(tv1);

auto tv2 = matmul(tv0, tv1);

fusion.addOutput(tv2);

auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
auto a_ref = at::randn({M, K}, options);
// auto b_ref = at::randn({N, K}, options).t();
auto b_ref = at::randn({K, N}, options);
auto out_ref = at::matmul(a_ref, b_ref);

MatMulTileOptions gemm_tile;
gemm_tile.cta_tile = GemmTile(128, 256, 16);
gemm_tile.warp_tile = GemmTile(64, 64, 16);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 8};
mparams.mma_macro = MmaMacro::Hopper_64_64_16;
mparams.tile_sizes = gemm_tile;
mparams.cta_order = MatmulParams::TileRasterizationOrder::RowMajor;
mparams.async_gmem_load_operands = true;
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = false;
mparams.circular_buffer_options.smem_circular_buffer_stage = 3;
mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
mparams.splitk_factor = 1;
mparams.use_smem_epilogue = true;
mparams.cluster_dims = {1, 1, 1};
mparams.promote_prologue_smem_reuse = true;

SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
->schedule(&fusion, &mparams);

std::vector<c10::IValue> inputs = {a_ref, b_ref};

KernelExecutor ke;
ke.compile(&fusion, inputs);
kir::Kernel* kernel = ke.compiledKernel()->kernel();
ASSERT_TRUE(kernel != nullptr);
EXPECT_TRUE(getBankConflictInfo(kernel).empty());
EXPECT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(kernel));

auto cg_outputs = ke.run(inputs);

// Relax tolerance for larger sum due to large K
EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
}

} // namespace nvfuser
7 changes: 6 additions & 1 deletion tests/cpp/test_translate_mma.cpp
@@ -315,14 +315,19 @@ using MatmulNodeTranslationTest =
// Test that a simple matmul op fusion is picked up by the appropriate scheduler
// and the translation to MmaOp is performed properly.
TEST_P(MatmulNodeTranslationTest, AutomaticSchedulerMatmulNode) {
-NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 10, 0);
const int64_t A_dim = std::get<0>(GetParam());
const int64_t B_dim = std::get<1>(GetParam());
const bool enable_fusion = std::get<2>(GetParam());
const bool transpose_a_alloc = std::get<3>(GetParam());
const bool expect_segmented = std::get<4>(GetParam());
const SchedulerType expected_heuristic = std::get<5>(GetParam());

if (A_dim == 3 && B_dim == 2) {
// TODO: Fix the failure at checkConcreteStaticDim on Hopper in this case
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);
}

// CombineMulSumAsMmaTest disabled MatmulExprEval, but we need it enabled
DisableOptionsGuard dog;
DisableOptionsGuard::getCurOptions().unset(DisableOption::MatmulExprEval);