Skip to content

Commit

Permalink
split k dim if warp tile is larger than macro
Browse files Browse the repository at this point in the history
  • Loading branch information
rdspring1 committed Jan 31, 2025
1 parent 919eeb1 commit 74f2056
Showing 1 changed file with 24 additions and 10 deletions.
34 changes: 24 additions & 10 deletions csrc/scheduler/hopper_multi_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,16 +427,30 @@ void HopperMultipleMatmulScheduler::scheduleMmaResults() {
// Original: [..., Mo, No, Mi, Ni, Ki]
mma_result->split(-3, getM(params_->mma_macro));
mma_result->split(-2, getN(params_->mma_macro));
mma_result->split(-1, getK(params_->mma_macro));
// After Split: [..., Mo, No, Mio, Mii, Nio, Nii, Kio, Kii]
mma_result->reorder({{-5, -4}});
// After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii, Kio, Kii]
mma_result->reorder({{-2, -4}});
// After Reorder: [..., Mo, No, Mio, Nio, Kio, Mii, Nii, Kii]
mma_result->merge(-6);
// After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
mma_result->axis(-5)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]

// Split k dimension of warp tile only if it is larger than k dimension of
// mma macro. Inlining can be at incorrect position for circular buffering
// if a reduction iterDomain has iterDomain 1.
if (params_->tile_sizes.warp_tile.k > getK(params_->mma_macro)) {
mma_result->split(-1, getK(params_->mma_macro));
// After Split: [..., Mo, No, Mio, Mii, Nio, Nii, Kio, Kii]
mma_result->reorder({{-5, -4}});
// After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii, Kio, Kii]
mma_result->reorder({{-2, -4}});
// After Reorder: [..., Mo, No, Mio, Nio, Kio, Mii, Nii, Kii]
mma_result->merge(-6);
// After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
mma_result->axis(-5)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
} else {
// After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
mma_result->reorder({{-4, -3}});
// After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
mma_result->merge(-5);
// After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
mma_result->axis(-4)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
}

auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
mma_result->getLoopDomain());
Expand Down

0 comments on commit 74f2056

Please sign in to comment.