Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue multiple wgmma operations when CTA k dim is a multiple of 16 #3616

Merged
merged 9 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/device_lower/pass/allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ class AllocationInserter : public kir::ExprMutator {
// generic-async proxy fence and wgmma fence before each mma
// instruction. For this case, we need to insert these fences
// after the initialization of the accumulator, so that the
// inilization is visible to the async proxy.
// initialization is visible to the async proxy.
// When all inputs are guarded by mbarrier, we will insert these
// fences before each mma instruction, so there is no need to
// insert them after the initialization of the accumulator here.
Expand Down
4 changes: 2 additions & 2 deletions csrc/device_lower/pass/insert_syncs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,7 @@ class ReadAfterWriteSyncs : public kir::ExprMutator {
}
};

// Insert wait expressions for WAR harzard for async operations such as wgmma
// Insert wait expressions for WAR hazard for async operations such as wgmma
// and tma store. To do so, we find the structure like the following example:
// for 1
// for 2
Expand Down Expand Up @@ -969,7 +969,7 @@ class WarAsyncWaitInserter : private kir::ExprMutator {
// that consumes the circular buffered tensor, the "pending_ops" can be larger
// than 0, depending on the prefetch distance and the stage depth of the
// circular buffer loop. When the prefetch distance is smaller than
// stage_depth - 1, we have have buffers for eliminating WAR harzards, so we
stage_depth - 1, we have buffers for eliminating WAR hazards, so we
// can allow more pending transactions.
int64_t getPendingOpsFor(Expr* expr, ForLoop* current_loop) {
auto for_loops_including_current = for_loops_;
Expand Down
47 changes: 30 additions & 17 deletions csrc/scheduler/hopper_multi_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,24 @@

namespace nvfuser {

void HopperMultipleMatmulScheduler::transformLikeMmaOutput(
TensorView* tv,
bool is_mma_result) {
// TODO Add constraints

auto apply_k_dim_offset = [is_mma_result](int64_t idx) constexpr {
return (is_mma_result) ? idx - 1 : idx;
};
void HopperMultipleMatmulScheduler::transformLikeMmaOutput(TensorView* tv) {
NVF_ERROR(
tv->domain()->loop().size() >= 4,
"transformLikeMmaOutput requires at least four iterDomains but ",
tv->toString(),
" only has ",
tv->domain()->loop().size(),
".");

// Original: [..., Mo, No, Mi, Ni]
tv->split(apply_k_dim_offset(-2), getM(params_->mma_macro));
tv->split(apply_k_dim_offset(-1), getN(params_->mma_macro));
tv->split(-2, getM(params_->mma_macro));
tv->split(-1, getN(params_->mma_macro));
// After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
tv->reorder({{apply_k_dim_offset(-3), apply_k_dim_offset(-2)}});
tv->reorder({{-3, -2}});
// After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
tv->merge(apply_k_dim_offset(-4));
tv->merge(-4);
// After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
tv->axis(apply_k_dim_offset(-3))->parallelize(ParallelType::TIDy);
tv->axis(-3)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
}

Expand Down Expand Up @@ -424,7 +424,20 @@ void HopperMultipleMatmulScheduler::scheduleMmaResults() {
splitk_sums_.push_back(splitk_sum);
}

transformLikeMmaOutput(mma_result, /*is_mma_result=*/true);
// Original: [..., Mo, No, Mi, Ni, Ki]
mma_result->split(-3, getM(params_->mma_macro));
mma_result->split(-2, getN(params_->mma_macro));
mma_result->split(-1, getK(params_->mma_macro));
// After Split: [..., Mo, No, Mio, Mii, Nio, Nii, Kio, Kii]
mma_result->reorder({{-5, -4}});
// After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii, Kio, Kii]
mma_result->reorder({{-2, -4}});
// After Reorder: [..., Mo, No, Mio, Nio, Kio, Mii, Nii, Kii]
mma_result->merge(-6);
// After Merge: [..., Mo, No, Mio * Nio, Kio, Mii, Nii, Kii]
mma_result->axis(-5)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Kio, Mii, Nii, Kii]

auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
mma_result->getLoopDomain());
mma_result->setAllocationDomain(s.as<IterDomain*>(), true);
Expand Down Expand Up @@ -459,7 +472,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
// op.
blockTileTensors({d});
parallelizeBlocks({d});
transformLikeMmaOutput(d, /*is_mma_result=*/false);
transformLikeMmaOutput(d);

auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
d->getLoopDomain());
Expand Down Expand Up @@ -539,7 +552,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
blockTileTensors(tvs_to_schedule);
parallelizeBlocks(tvs_to_schedule);
for (auto tv : tvs_to_schedule) {
transformLikeMmaOutput(tv, /*is_mma_result=*/false);
transformLikeMmaOutput(tv);
}

// Should not propagate if the dc is a mma output as the mma output has
Expand Down Expand Up @@ -590,7 +603,7 @@ void HopperMultipleMatmulScheduler::scheduleSplitKSum() {
for (TensorView* splitk_sum : splitk_sums_) {
// Always use serial grid reduction for split-K sum
splitk_sum->definition()->as<ReductionOp>()->requestSerialGridReduction();
transformLikeMmaOutput(splitk_sum, /*is_mma_result=*/false);
transformLikeMmaOutput(splitk_sum);
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
splitk_sum->getLoopDomain());
splitk_sum->setLoopDomain(s.as<IterDomain*>());
Expand Down
2 changes: 1 addition & 1 deletion csrc/scheduler/hopper_multi_matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler {
// Schedule a block-tiled TensorView like mma output.
// Why? WGMMA has a unique output format. TensorViews after the mma-result in
// registers must respect this format for correctness.
void transformLikeMmaOutput(TensorView* tv, bool is_mma_result);
void transformLikeMmaOutput(TensorView* tv);

private:
std::vector<ValGroup> canonical_dim_ordering_;
Expand Down
16 changes: 8 additions & 8 deletions tests/cpp/test_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4028,8 +4028,8 @@ TEST_F(HopperMatmulTest, HSH_NT_UseScheduler) {
auto out_ref = at::matmul(a_ref.squeeze().t(), b_ref.squeeze()).to(at::kHalf);

MatMulTileOptions gemm_tile;
gemm_tile.cta_tile = GemmTile(128, 256, 16);
gemm_tile.warp_tile = GemmTile(64, 256, 16);
gemm_tile.cta_tile = GemmTile(128, 256, 32);
gemm_tile.warp_tile = GemmTile(64, 256, 32);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 8};
Expand Down Expand Up @@ -4085,8 +4085,8 @@ TEST_F(HopperMatmulTest, HSH_TN_UseScheduler) {
auto out_ref = at::matmul(a_ref.squeeze(), b_ref.squeeze().t()).to(at::kHalf);

MatMulTileOptions gemm_tile;
gemm_tile.cta_tile = GemmTile(128, 256, 16);
gemm_tile.warp_tile = GemmTile(64, 256, 16);
gemm_tile.cta_tile = GemmTile(128, 256, 32);
gemm_tile.warp_tile = GemmTile(64, 256, 32);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 8};
Expand Down Expand Up @@ -4148,8 +4148,8 @@ TEST_F(HopperMatmulTest, HSH_NN_UseScheduler) {
at::matmul(a_ref.squeeze().t(), b_ref.squeeze().t()).to(at::kHalf);

MatMulTileOptions gemm_tile;
gemm_tile.cta_tile = GemmTile(128, 256, 16);
gemm_tile.warp_tile = GemmTile(64, 256, 16);
gemm_tile.cta_tile = GemmTile(128, 256, 32);
gemm_tile.warp_tile = GemmTile(64, 256, 32);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 8};
Expand Down Expand Up @@ -4210,8 +4210,8 @@ TEST_F(HopperMatmulTest, HSH_TT_UseScheduler) {
auto out_ref = at::matmul(a_ref.squeeze(), b_ref.squeeze()).to(at::kHalf);

MatMulTileOptions gemm_tile;
gemm_tile.cta_tile = GemmTile(128, 256, 16);
gemm_tile.warp_tile = GemmTile(64, 256, 16);
gemm_tile.cta_tile = GemmTile(128, 256, 32);
gemm_tile.warp_tile = GemmTile(64, 256, 32);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 8};
Expand Down
8 changes: 4 additions & 4 deletions tests/cpp/test_matmul_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3329,13 +3329,13 @@ class HopperMatmulSchedulerTest
// TODO cta tile is a multiple of mma macro for hopper.
// Default cta_tile configuration is 2-CTA.
gemm_tile.cta_tile =
GemmTile(2 * getM(mma_macro), getN(mma_macro), getK(mma_macro));
GemmTile(2 * getM(mma_macro), getN(mma_macro), 2 * getK(mma_macro));

// TODO warp tile is (macroM, macroN, macroK) for hopper.
gemm_tile.warp_tile =
GemmTile(getM(mma_macro), getN(mma_macro), getK(mma_macro));
GemmTile(getM(mma_macro), getN(mma_macro), 2 * getK(mma_macro));

mparams.supported_vec_size = {8, 8, 4};
mparams.supported_vec_size = {8, 8, 8};
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, although I think it currently has no meaning until we start handling epilogue inputs with supported vec size.


mparams.mma_macro = mma_macro;

Expand Down Expand Up @@ -3523,7 +3523,7 @@ INSTANTIATE_TEST_SUITE_P(
testing::Bool(), // b_k_inner
testing::Values(512), // M
testing::Values(256), // N
testing::Values(64), // K
testing::Values(128), // K
testing::Values(MmaMacro::Hopper_64_128_16), // mma_macros
testing::Values(1, 2) // SplitK Factor
),
Expand Down