Split Hopper MMA by warp-tile before instruction tile #3642
```diff
@@ -29,25 +29,61 @@
 namespace nvfuser {
 
-void HopperMultipleMatmulScheduler::transformLikeMmaOutput(TensorView* tv) {
-  NVF_ERROR(
-      tv->domain()->loop().size() >= 4,
-      "transformLikeMmaOutput requires at least four iterDomains but ",
-      tv->toString(),
-      " only has ",
-      tv->domain()->loop().size(),
-      ".");
-
-  // Original: [..., Mo, No, Mi, Ni]
+void HopperMultipleMatmulScheduler::transformLikeMmaOutputWithK(
+    TensorView* tv) {
+  // The input is originally block tiled so that the inner dims are the CTA
+  // tile size.
+  //
+  // We split this into warp tiles, then instruction tiles.
+  // Original: [..., M, N, K]
+  tv->split(-3, params_->tile_sizes.warp_tile.m);
+  tv->split(-3, getM(params_->mma_macro));
+  tv->split(-2, params_->tile_sizes.warp_tile.n);
+  tv->split(-2, getN(params_->mma_macro));
+  // The K dimension is present for mma_result. We don't need to split by
+  // warp_tile.k, since we always have cta_tile.k == warp_tile.k.
+  tv->split(-1, getK(params_->mma_macro));
+  // After Split: [..., Mo, Mw, Mi, No, Nw, Ni, Kw, Ki]
+  tv->reorder({
+      {-8, -8}, // Mo
+      {-7, -6}, // Mw
+      {-6, -3}, // Mi
+      {-5, -7}, // No
+      {-4, -5}, // Nw
+      {-3, -2}, // Ni
+      {-2, -4}, // Kw
+      {-1, -1}, // Ki
+  });
+  // After Reorder: [..., Mo, No, Mw, Nw, Kw, Mi, Ni, Ki]
+  tv->merge(-8);
+  // After Merge: [..., Mo * No, Mw, Nw, Kw, Mi, Ni, Ki]
+  tv->axis(-7)->parallelize(ParallelType::TIDy);
+  // After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Kw, Mi, Ni, Ki]
+}
+
+void HopperMultipleMatmulScheduler::transformLikeMmaOutputWithoutK(
+    TensorView* tv) {
+  // TODO Add constraints
+
+  // The input is originally block tiled so that the inner dims are the CTA
+  // tile size.
+  // Original: [..., M, N]
+  // We split this into warp tiles, then instruction tiles.
+  tv->split(-2, params_->tile_sizes.warp_tile.m);
   tv->split(-2, getM(params_->mma_macro));
+  tv->split(-1, params_->tile_sizes.warp_tile.n);
   tv->split(-1, getN(params_->mma_macro));
-  // After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
-  tv->reorder({{-3, -2}});
-  // After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
-  tv->merge(-4);
-  // After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
-  tv->axis(-3)->parallelize(ParallelType::TIDy);
-  // After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
+  // After Split: [..., Mo, Mw, Mi, No, Nw, Ni]
+  tv->reorder({
+      {-3, -5},
+      {-2, -3},
+  });
+  // After Reorder: [..., Mo, No, Mw, Nw, Mi, Ni]
+  tv->merge(-6);
+  // After Merge: [..., Mo * No, Mw, Nw, Mi, Ni]
+  tv->axis(-5)->parallelize(ParallelType::TIDy);
+  // After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Mi, Ni]
 }
 
 MatmulDimRole HopperMultipleMatmulScheduler::findMatmulDimRole(IterDomain* id) {
```
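The reorder map in `transformLikeMmaOutputWithK` is dense enough to be easy to mis-read, so here is a small self-contained sketch (toy code, not the nvfuser API; `split`/`reorder`/`merge` below are simplified stand-ins that only track axis labels) that replays the bookkeeping and reproduces the `After ...` comments:

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

using Domain = std::vector<std::string>;

// Split axis at negative position `pos` into `outer` and `inner` labels.
void split(Domain& d, int pos, const std::string& outer, const std::string& inner) {
  size_t i = d.size() + pos; // negative index from the right
  d[i] = outer;
  d.insert(d.begin() + i + 1, inner);
}

// Total reorder: map from old negative position to new negative position.
void reorder(Domain& d, const std::map<int, int>& old2new) {
  Domain out(d.size());
  for (auto [o, n] : old2new) {
    out[d.size() + n] = d[d.size() + o];
  }
  d = out;
}

// Merge axis at negative position `pos` with its right-hand neighbor.
void merge(Domain& d, int pos) {
  size_t i = d.size() + pos;
  d[i] += "*" + d[i + 1];
  d.erase(d.begin() + i + 1);
}

void print(const std::string& stage, const Domain& d) {
  std::cout << stage << ": [";
  for (size_t i = 0; i < d.size(); ++i) {
    std::cout << (i ? ", " : "") << d[i];
  }
  std::cout << "]\n";
}

int main() {
  Domain d = {"M", "N", "K"}; // trailing CTA-tile axes of mma_result
  split(d, -3, "Mo", "Mw");   // tv->split(-3, warp_tile.m)
  split(d, -3, "Mw", "Mi");   // tv->split(-3, getM(macro))
  split(d, -2, "No", "Nw");   // tv->split(-2, warp_tile.n)
  split(d, -2, "Nw", "Ni");   // tv->split(-2, getN(macro))
  split(d, -1, "Kw", "Ki");   // tv->split(-1, getK(macro))
  print("After Split", d);    // [Mo, Mw, Mi, No, Nw, Ni, Kw, Ki]
  reorder(d, {{-8, -8}, {-7, -6}, {-6, -3}, {-5, -7},
              {-4, -5}, {-3, -2}, {-2, -4}, {-1, -1}});
  print("After Reorder", d);  // [Mo, No, Mw, Nw, Kw, Mi, Ni, Ki]
  merge(d, -8);
  print("After Merge", d);    // [Mo*No, Mw, Nw, Kw, Mi, Ni, Ki]
}
```

Running it prints the three domain snapshots; the `axis(-7)` bound to `TIDy` after the merge is the combined `Mo*No` axis, matching the comments in `transformLikeMmaOutputWithK`.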
```diff
@@ -365,6 +401,7 @@ void HopperMultipleMatmulScheduler::scheduleOperands() {
       const std::vector<TensorView*>& smem_operands,
       MmaOperand operand_type) {
     blockTileTensors(smem_operands);
+    parallelizeBlocks(smem_operands);
     for (TensorView* tv : smem_operands) {
       if (params_->promote_prologue_smem_reuse) {
         tv->promoteReuse();
```

Review comment on the added `parallelizeBlocks` call: We might as well also parallelize these. Note that we could just call this from […]
```diff
@@ -452,33 +489,7 @@ void HopperMultipleMatmulScheduler::scheduleMmaResults() {
       splitk_sums_.push_back(splitk_sum);
     }
 
-    // Original: [..., Mo, No, Mi, Ni, Ki]
-    mma_result->split(-3, getM(params_->mma_macro));
-    mma_result->split(-2, getN(params_->mma_macro));
-
-    // Split k dimension of warp tile only if it is larger than k dimension of
-    // mma macro. Inlining can be at incorrect position for circular buffering
-    // if a reduction iterDomain has extent 1.
-    if (params_->tile_sizes.warp_tile.k > getK(params_->mma_macro)) {
-      mma_result->split(-1, getK(params_->mma_macro));
-      // After Split: [..., Mo, No, Mio, Mii, Nio, Nii, Kio, Kii]
-      mma_result->reorder({{-5, -4}});
-      // After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii, Kio, Kii]
-      mma_result->reorder({{-2, -4}});
-      // After Reorder: [..., Mo, No, Mio, Nio, Kio, Mii, Nii, Kii]
-      mma_result->merge(-6);
-      // After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
-      mma_result->axis(-5)->parallelize(ParallelType::TIDy);
-      // After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
-    } else {
-      // After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
-      mma_result->reorder({{-4, -3}});
-      // After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
-      mma_result->merge(-5);
-      // After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
-      mma_result->axis(-4)->parallelize(ParallelType::TIDy);
-      // After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
-    }
+    transformLikeMmaOutputWithK(mma_result);
 
     auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
         mma_result->getLoopDomain());
```
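Aside (a toy calculation, not nvfuser code): the replacement call always splits K by the macro's K extent, whereas the old code split conditionally. For the `Hopper_64_64_16` macro the outer `Kw` extent works out as below; the `warp_tile.k == 16` row is the case the old else-branch used to special-case, since it yields a `Kw` of extent 1.

```cpp
#include <iostream>

int main() {
  constexpr int macro_k = 16; // getK(MmaMacro::Hopper_64_64_16)
  for (int warp_k : {16, 32, 64}) {
    // Kw is the outer result of splitting the K loop by macro_k.
    std::cout << "warp_tile.k=" << warp_k << " -> Kw extent "
              << warp_k / macro_k << "\n";
  }
}
```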
```diff
@@ -514,7 +525,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
       // op.
       blockTileTensors({d});
       parallelizeBlocks({d});
-      transformLikeMmaOutput(d);
+      transformLikeMmaOutputWithoutK(d);
 
       auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
           d->getLoopDomain());
```
```diff
@@ -545,8 +556,8 @@
     // tile is a multiple of the macro size because stmatrix stores results from
     // wgmma to shared memory. For maximum inlining and to reduce shared memory
     // usage, the tma tile is mma_macro size.
-    const int64_t tma_m = getM(params_->mma_macro);
-    const int64_t tma_n = getN(params_->mma_macro);
+    const int64_t tma_m = params_->tile_sizes.warp_tile.m;
+    const int64_t tma_n = params_->tile_sizes.warp_tile.n;
 
     fusion_->manage("st_matrix_m_tile", stmatrix_tile_m);
     fusion_->manage("st_matrix_n_tile", stmatrix_tile_n);
```
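For scale (illustrative arithmetic, assuming a Half-precision epilogue as in the tests below): moving the TMA store tile from the macro size to the warp-tile size changes the per-tile shared-memory footprint accordingly.

```cpp
#include <cstdint>
#include <iostream>

int main() {
  constexpr int64_t bytes_per_elem = 2;         // Half
  constexpr int64_t macro_m = 64, macro_n = 64; // e.g. Hopper_64_64_16
  constexpr int64_t warp_m = 128, warp_n = 128; // e.g. warp_tile in the tests
  std::cout << "macro-sized tma tile: "
            << macro_m * macro_n * bytes_per_elem / 1024 << " KiB\n"; // 8 KiB
  std::cout << "warp-sized tma tile:  "
            << warp_m * warp_n * bytes_per_elem / 1024 << " KiB\n";   // 32 KiB
}
```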
```diff
@@ -594,7 +605,7 @@
     blockTileTensors(tvs_to_schedule);
     parallelizeBlocks(tvs_to_schedule);
     for (auto tv : tvs_to_schedule) {
-      transformLikeMmaOutput(tv);
+      transformLikeMmaOutputWithoutK(tv);
     }
 
     // Should not propagate if the dc is a mma output as the mma output has
```
```diff
@@ -645,7 +656,7 @@ void HopperMultipleMatmulScheduler::scheduleSplitKSum() {
   for (TensorView* splitk_sum : splitk_sums_) {
     // Always use serial grid reduction for split-K sum
     splitk_sum->definition()->as<ReductionOp>()->requestSerialGridReduction();
-    transformLikeMmaOutput(splitk_sum);
+    transformLikeMmaOutputWithoutK(splitk_sum);
     auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
         splitk_sum->getLoopDomain());
     splitk_sum->setLoopDomain(s.as<IterDomain*>());
```
```diff
@@ -4519,4 +4519,142 @@ INSTANTIATE_TEST_SUITE_P(
       return ss.str();
     });
 
+// This tests that we can use a small instruction tile with a medium-size
+// warpgroup tile and a large CTA tile.
+TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  constexpr int64_t M = 2048, N = 2048, K = 8192;
+  const auto dtype = DataType::Half;
+
+  auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // K, M
+  auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype); // K, N
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+
+  auto tv2 = fusedMultiplySum(tv0, tv1, {0});
+
+  // Reorder the accumulator as [M, N, K]
+  // [K, M, N] -> [M, N, K]
+  tv2->reorder({{-3, -1}});
+  tv2->commitLeafToLogical();
+
+  auto tv3 = castOp(DataType::Half, tv2);
+  fusion.addOutput(tv3);
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
+  auto a_ref = at::randn({K, M, 1}, options);
+  auto b_ref = at::randn({K, 1, N}, options);
+  auto out_ref = at::matmul(a_ref.squeeze().t(), b_ref.squeeze()).to(at::kHalf);
+
+  MatMulTileOptions gemm_tile;
+  // Regardless of the instruction, this should result in 4 warp groups, i.e.
+  // 512 compute threads
+  gemm_tile.cta_tile = GemmTile(256, 256, 32);
+  gemm_tile.warp_tile = GemmTile(128, 128, 32);
+
+  MatmulParams mparams;
+  mparams.supported_vec_size = {8, 8, 8};
+  mparams.mma_macro = MmaMacro::Hopper_64_64_16;
+  mparams.tile_sizes = gemm_tile;
+  mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor;
+  mparams.async_gmem_load_operands = true;
+  mparams.circular_buffer_options.circular_buffer_smem_write = true;
+  mparams.circular_buffer_options.circular_buffer_smem_read = false;
+  mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
+  mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
+  mparams.splitk_factor = 1;
+  // NOTE: disabling smem use for this test since we currently hit a bank
+  // conflict.
+  // TODO: enable smem epilogue once stmatrix is updated
+  mparams.use_smem_epilogue = false;
+  mparams.cluster_dims = {2, 1, 1};
+  mparams.promote_prologue_smem_reuse = false;
+
+  SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
+      ->schedule(&fusion, &mparams);
+
+  std::vector<c10::IValue> inputs = {a_ref, b_ref};
+
+  KernelExecutor ke;
+  ke.compile(&fusion, inputs);
+  kir::Kernel* kernel = ke.compiledKernel()->kernel();
+  ASSERT_TRUE(kernel != nullptr);
+  EXPECT_TRUE(getBankConflictInfo(kernel).empty());
+  EXPECT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(kernel));
+
+  auto cg_outputs = ke.run(inputs);
+
+  // Check that the number of launched threads matches what we expect
+  EXPECT_EQ(ke.lastLaunchParams().bdimx(), 128);
+  EXPECT_EQ(ke.lastLaunchParams().bdimy(), 4)
+      << " expected 4 warp groups (bdimy==4) but found bdimy=="
+      << ke.lastLaunchParams().bdimy();
+
+  // Relax tolerance for larger sum due to large K
+  EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
+}
```
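A back-of-envelope check of those launch-dimension expectations (a hypothetical helper, not part of the test suite): the warp-group count depends only on the ratio of the CTA tile to the warp tile, which is why the comment says "regardless of the instruction".

```cpp
#include <iostream>

int main() {
  constexpr int cta_m = 256, cta_n = 256;   // gemm_tile.cta_tile
  constexpr int warp_m = 128, warp_n = 128; // gemm_tile.warp_tile
  // One warp group (128 threads, bdimx) per warp tile; warp tiles are
  // parallelized over TIDy, so bdimy should equal this count.
  constexpr int warp_groups = (cta_m / warp_m) * (cta_n / warp_n);
  std::cout << "warp groups (bdimy): " << warp_groups << "\n";        // 4
  std::cout << "compute threads:     " << warp_groups * 128 << "\n";  // 512
}
```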
Review comment: This test is pretty much identical to the previous one, but it uses a […]

```diff
+TEST_F(HopperMatmulTest, ScheduleWithTranslation) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  constexpr int64_t M = 2048, N = 2048, K = 8192;
+  const auto dtype = DataType::Half;
+
+  auto tv0 = makeContigConcreteTensor({-1, -1}, dtype); // M, K
+  auto tv1 = makeContigConcreteTensor({-1, -1}, dtype); // K, N
+  // Note tv1 has allocation domain
+  // tv1->setAllocationDomain({tv1->axis(1), tv1->axis(0)}, true);
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+
+  auto tv2 = matmul(tv0, tv1);
+
+  fusion.addOutput(tv2);
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
+  auto a_ref = at::randn({M, K}, options);
+  // auto b_ref = at::randn({N, K}, options).t();
+  auto b_ref = at::randn({K, N}, options);
+  auto out_ref = at::matmul(a_ref, b_ref);
+
+  MatMulTileOptions gemm_tile;
+  gemm_tile.cta_tile = GemmTile(128, 256, 16);
+  gemm_tile.warp_tile = GemmTile(64, 64, 16);
+
+  MatmulParams mparams;
+  mparams.supported_vec_size = {8, 8, 8};
+  mparams.mma_macro = MmaMacro::Hopper_64_64_16;
+  mparams.tile_sizes = gemm_tile;
+  mparams.cta_order = MatmulParams::TileRasterizationOrder::RowMajor;
+  mparams.async_gmem_load_operands = true;
+  mparams.circular_buffer_options.circular_buffer_smem_write = true;
+  mparams.circular_buffer_options.circular_buffer_smem_read = false;
+  mparams.circular_buffer_options.smem_circular_buffer_stage = 3;
+  mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
+  mparams.splitk_factor = 1;
+  mparams.use_smem_epilogue = true;
+  mparams.cluster_dims = {1, 1, 1};
+  mparams.promote_prologue_smem_reuse = true;
+
+  SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
+      ->schedule(&fusion, &mparams);
+
+  std::vector<c10::IValue> inputs = {a_ref, b_ref};
+
+  KernelExecutor ke;
+  ke.compile(&fusion, inputs);
+  kir::Kernel* kernel = ke.compiledKernel()->kernel();
+  ASSERT_TRUE(kernel != nullptr);
+  EXPECT_TRUE(getBankConflictInfo(kernel).empty());
+  EXPECT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(kernel));
+
+  auto cg_outputs = ke.run(inputs);
+
+  // Relax tolerance for larger sum due to large K
+  EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
+}
+
 } // namespace nvfuser
```
Review comment: Do we have any conditions to check here?

Reply: Good point. I guess we should check that the inner dim is reduction at least.
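A minimal sketch of the suggested guard (an assumption, not part of this PR: it would sit at the top of `transformLikeMmaOutputWithK`, where the trailing axes are expected to be `[..., M, N, K]` with K a reduction domain; `NVF_ERROR` and `IterDomain::isReduction()` are used as in the surrounding codebase):

```cpp
// Hypothetical constraint check sketched from the review discussion above.
NVF_ERROR(
    tv->domain()->loop().size() >= 3,
    "transformLikeMmaOutputWithK expects at least [..., M, N, K] but ",
    tv->toString(),
    " only has ",
    tv->domain()->loop().size(),
    " iterDomains.");
NVF_ERROR(
    tv->axis(-1)->isReduction(),
    "Expected the innermost iterDomain of ",
    tv->toString(),
    " to be a (K) reduction domain.");
```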