Implement persistent matmul scheduling #3812

Open

Wants to merge 42 commits into base: main

Changes from 31 commits

Commits (42)
851669a
Split Hopper MMA by warp-tile before instruction tile
jacobhinkle Dec 24, 2024
8b42cd6
Use 4 warpgroups, disable smem epilogue
jacobhinkle Dec 31, 2024
7c6d417
Merge branch 'main' into hopper_warptile_split
jacobhinkle Dec 31, 2024
521d5cc
Use warp_tile for tma_m and tma_n
jacobhinkle Dec 31, 2024
dce16ad
Two warp tiles per CTA in each dim, increase instr to 64_64_16
jacobhinkle Jan 2, 2025
f5e084c
Also split by K
jacobhinkle Jan 2, 2025
be705bf
Add ScheduleWithTranslation test (failing)
jacobhinkle Jan 7, 2025
9de3202
Merge remote-tracking branch 'origin/main' into hopper_warptile_split
jacobhinkle Jan 8, 2025
41e2b94
Merge remote-tracking branch 'origin/main' into hopper_warptile_split
jacobhinkle Jan 17, 2025
e010ead
Merge remote-tracking branch 'origin/main' into hopper_warptile_split
jacobhinkle Jan 28, 2025
5246fb3
Update to fix compilation
jacobhinkle Jan 28, 2025
1dccf22
Don't do K split. Fix TMA offset
jacobhinkle Jan 28, 2025
496d8d7
Merge remote-tracking branch 'origin/main' into hopper_warptile_split
jacobhinkle Jan 29, 2025
dfa6ff8
Add options for warp specialization and persistence strategy
jacobhinkle Jan 29, 2025
21c508d
Temporarily revert change to scheduleStMatrixForMmaOutput
jacobhinkle Jan 29, 2025
174deda
Parametrize MLP Benchmark tests to run three configurations
jacobhinkle Jan 29, 2025
7868900
Unguard most matmul node translation tests on Hopper
jacobhinkle Jan 29, 2025
9b5e73c
lintrunner
jacobhinkle Jan 29, 2025
21a2710
Apply suggestions from code review
jacobhinkle Jan 29, 2025
f1fff43
Reparametrize and place a big comment explaining
jacobhinkle Jan 30, 2025
a3b8fd4
Update python bindings
jacobhinkle Jan 31, 2025
bfd65f3
Add more checks for valid configs
jacobhinkle Jan 31, 2025
694e0fe
Set warp specialization as default on hopper
jacobhinkle Jan 31, 2025
794285b
Merge remote-tracking branch 'origin/main' into jh/persistent_kernel_…
jacobhinkle Jan 31, 2025
95cf199
Guard MLPBenchmarkTest to Hopper only
jacobhinkle Jan 31, 2025
ffa276e
Merge remote-tracking branch 'origin/hopper_warptile_split' into jh/p…
jacobhinkle Jan 31, 2025
86d75de
Merge in from #3642. Add persistent change
jacobhinkle Jan 31, 2025
6d98405
Add BroadcastInputs tests
jacobhinkle Feb 3, 2025
438e1a0
Remove debug prints
jacobhinkle Feb 3, 2025
68c07a0
Merge remote-tracking branch 'origin/main' into jh/persistent_kernel_…
jacobhinkle Feb 3, 2025
4d0226c
Fix block parallelization
jacobhinkle Feb 3, 2025
07c93c6
Override params for horizontal fusion tests
jacobhinkle Feb 6, 2025
37a7282
Merge commit '9dc94c0' into jh/persistent_kernel_impl
jacobhinkle Feb 6, 2025
74751b3
Merge commit '3ac19f0' into jh/persistent_kernel_impl
jacobhinkle Feb 6, 2025
2527dc0
Merge commit 'a1baafa' into jh/persistent_kernel_impl
jacobhinkle Feb 6, 2025
e4486c8
Merge remote-tracking branch 'origin/main' into jh/persistent_kernel_…
jacobhinkle Feb 6, 2025
b0359a2
Merge remote-tracking branch 'origin/main' into jh/persistent_kernel_…
jacobhinkle Feb 6, 2025
7f161bc
Uncomment correctness checks in tests
jacobhinkle Feb 6, 2025
0cbb3e6
Guard failing MLPBenchmarkTest cases on Ampere
jacobhinkle Feb 6, 2025
dec223e
Merge remote-tracking branch 'origin/main' into jh/persistent_kernel_…
jacobhinkle Feb 11, 2025
913ba63
Don't do register sharing for OneTilePerCTA
jacobhinkle Feb 11, 2025
4ef0339
Merge remote-tracking branch 'origin/main' into jh/persistent_kernel_…
jacobhinkle Feb 11, 2025
2 changes: 1 addition & 1 deletion csrc/device_lower/pass/index.cpp
@@ -2094,7 +2094,7 @@ void IndexLowering::handle(const LoadStoreOp* ldst) {
2};
} else if (ir_utils::isStMatrixOp(ldst)) {
NVF_ERROR(
-        ldst->out()->as<TensorView>()->getLogicalDomain().size() == 2,
+        ldst->out()->as<TensorView>()->getLogicalDomain().size() >= 2,
"We only support 2D inputs stmatrix");

NVF_ERROR(
140 changes: 101 additions & 39 deletions csrc/scheduler/hopper_multi_matmul.cpp
@@ -13,6 +13,7 @@
#include <scheduler/debug_utils.h>
#include <scheduler/hopper_multi_matmul.h>
#include <scheduler/matmul.h>
+#include <scheduler/matmul_heuristic.h>
#include <scheduler/matmul_utils.h>
#include <scheduler/mma_utils.h>
#include <scheduler/tools/abstract_tensor.h>
@@ -29,25 +30,61 @@

namespace nvfuser {

-void HopperMultipleMatmulScheduler::transformLikeMmaOutput(
-    TensorView* tv,
-    bool is_mma_result) {
-  // TODO Add constraints
+void HopperMultipleMatmulScheduler::transformLikeMmaOutputWithK(
+    TensorView* tv) {
+  // The input is originally block tiled so that the inner dims are the CTA tile
+  // size
+  //
+  // We split this into warp tiles then instruction tiles
+  // Original: [..., M, N, K]
+  tv->split(-3, params_->tile_sizes.warp_tile.m);
+  tv->split(-3, getM(params_->mma_macro));
+  tv->split(-2, params_->tile_sizes.warp_tile.n);
+  tv->split(-2, getN(params_->mma_macro));
+  // K dimension is present for mma_result
+  // We don't need to split by warp_tile.k, since we always have
+  // cta_tile.k == warp_tile.k
+  tv->split(-1, getK(params_->mma_macro));
+  // After Split: [..., Mo, Mw, Mi, No, Nw, Ni, Kw, Ki]
+  tv->reorder({
+      {-8, -8}, // Mo
+      {-7, -6}, // Mw
+      {-6, -3}, // Mi
+      {-5, -7}, // No
+      {-4, -5}, // Nw
+      {-3, -2}, // Ni
+      {-2, -4}, // Kw
+      {-1, -1}, // Ki
+  });
+  // After Reorder: [..., Mo, No, Mw, Nw, Kw, Mi, Ni, Ki]
+  tv->merge(-8);
+  // After Merge: [..., Mo * No, Mw, Nw, Kw, Mi, Ni, Ki]
+  tv->axis(-7)->parallelize(ParallelType::TIDy);
+  // After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Kw, Mi, Ni, Ki]
+}

-  auto apply_k_dim_offset = [is_mma_result](int64_t idx) constexpr {
-    return (is_mma_result) ? idx - 1 : idx;
-  };
+void HopperMultipleMatmulScheduler::transformLikeMmaOutputWithoutK(
+    TensorView* tv) {
+  // TODO Add constraints

-  // Original: [..., Mo, No, Mi, Ni]
-  tv->split(apply_k_dim_offset(-2), getM(params_->mma_macro));
-  tv->split(apply_k_dim_offset(-1), getN(params_->mma_macro));
-  // After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
-  tv->reorder({{apply_k_dim_offset(-3), apply_k_dim_offset(-2)}});
-  // After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
-  tv->merge(apply_k_dim_offset(-4));
-  // After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
-  tv->axis(apply_k_dim_offset(-3))->parallelize(ParallelType::TIDy);
-  // After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
+  // The input is originally block tiled so that the inner dims are the CTA tile
+  // size
+  // Original: [..., M, N]
+  // We split this into warp tiles then instruction tiles
+  tv->split(-2, params_->tile_sizes.warp_tile.m);
+  tv->split(-2, getM(params_->mma_macro));
+  tv->split(-1, params_->tile_sizes.warp_tile.n);
+  tv->split(-1, getN(params_->mma_macro));
+  // After Split: [..., Mo, Mw, Mi, No, Nw, Ni]
+  tv->reorder({
+      {-3, -5},
+      {-2, -3},
+  });
+  // After Reorder: [..., Mo, No, Mw, Nw, Mi, Ni]
+  tv->merge(-6);
+  // After Merge: [..., Mo * No, Mw, Nw, Mi, Ni]
+  tv->axis(-5)->parallelize(ParallelType::TIDy);
+  // After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Mi, Ni]
}
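
To make these transforms concrete, here is a minimal standalone sketch (an editor's illustration, not nvfuser code) tracing the loop-domain extents that transformLikeMmaOutputWithK produces. The CTA tile, warp tile, and macro sizes are assumptions chosen to match the commit messages above (four warpgroups, a 64_64_16 instruction):

// Minimal sketch, assuming cta_tile 128x256x64, warp_tile 64x128x64, and a
// 64x64x16 mma macro. Only the extent arithmetic is shown.
#include <cstdio>

int main() {
  const int cta_m = 128, cta_n = 256, cta_k = 64;
  const int warp_m = 64, warp_n = 128; // warp_tile.k == cta_tile.k, no K warp split
  const int inst_m = 64, inst_n = 64, inst_k = 16;

  // Splits: [M, N, K] -> [Mo, Mw, Mi, No, Nw, Ni, Kw, Ki]
  const int Mo = cta_m / warp_m, Mw = warp_m / inst_m;
  const int No = cta_n / warp_n, Nw = warp_n / inst_n;
  const int Kw = cta_k / inst_k;

  // Reorder + merge + parallelize: [Mo * No (TIDy), Mw, Nw, Kw, Mi, Ni, Ki]
  printf("[%d (TIDy), %d, %d, %d, %d, %d, %d]\n",
         Mo * No, Mw, Nw, Kw, inst_m, inst_n, inst_k);
  // Prints [4 (TIDy), 1, 2, 4, 64, 64, 16]: four warpgroups, each stepping
  // serially over its Nw and Kw tiles while issuing 64x64x16 wgmma macros.
  return 0;
}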

MatmulDimRole HopperMultipleMatmulScheduler::findMatmulDimRole(IterDomain* id) {
Expand All @@ -69,11 +106,6 @@ void HopperMultipleMatmulScheduler::validate() const {
"Hopper matmul scheduler does not support scheduling persistent split-K kernels");
}

-  NVF_CHECK(
-      params_->tiling_strategy !=
-          MatmulParams::TilingStrategy::DistributeTilesAcrossSMs,
-      "Hopper matmul scheduler TEMPORARILY does not support persistent scheduling of tiles yet");

NVF_CHECK(
params_->tiling_strategy !=
MatmulParams::TilingStrategy::DistributeStagesAcrossSMs,
@@ -338,6 +370,21 @@ std::vector<std::vector<MatmulDimRole>> HopperMultipleMatmulScheduler::
}
}
}

+      if (params_->tiling_strategy ==
+          MatmulParams::TilingStrategy::DistributeTilesAcrossSMs) {
+        // Persistent kernel scheduling
+        if (params_->cta_order ==
+            MatmulParams::TileRasterizationOrder::ColumnMajor) {
+          tv->reorder(
+              {{num_device_and_batch_dims_, num_device_and_batch_dims_ + 1}});
+        }
+        tv->merge(num_device_and_batch_dims_, num_device_and_batch_dims_ + 1);
+
+        const int64_t num_sms =
+            at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
+        tv->split(num_device_and_batch_dims_, num_sms);
+      }
}
return all_merged_roles;
}
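
The merge-then-split above is what makes the kernel persistent: all CTA tiles are linearized (after an M/N reorder when rasterization is column-major) and then strided across the SM count. Below is an editor's sketch of the resulting tile-to-CTA assignment, with an assumed SM count and tile grid; at runtime num_sms is queried from the device exactly as in the hunk above.

// Sketch of the DistributeTilesAcrossSMs mapping after merge + split(num_sms).
// The SM count and the tile grid are assumptions for illustration.
#include <cstdio>

int main() {
  const int num_sms = 132;              // assumed, e.g. H100
  const int tiles_m = 16, tiles_n = 16; // assumed grid of CTA tiles
  const int num_tiles = tiles_m * tiles_n;

  // After the split the loop nest is [wave, num_sms (BIDx)]: persistent CTA
  // `bidx` handles tiles bidx, bidx + num_sms, bidx + 2 * num_sms, ...
  for (int bidx = 0; bidx < 2; ++bidx) { // show only the first two CTAs
    for (int tile = bidx; tile < num_tiles; tile += num_sms) {
      // Row-major rasterization: linear tile index = m * tiles_n + n
      printf("CTA %d, wave %d: tile (%d, %d)\n",
             bidx, tile / num_sms, tile / tiles_n, tile % tiles_n);
    }
  }
  return 0;
}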
@@ -365,6 +412,7 @@ void HopperMultipleMatmulScheduler::scheduleOperands() {
const std::vector<TensorView*>& smem_operands,
MmaOperand operand_type) {
blockTileTensors(smem_operands);
+    parallelizeBlocks(smem_operands);
for (TensorView* tv : smem_operands) {
if (params_->promote_prologue_smem_reuse) {
tv->promoteReuse();
@@ -381,21 +429,35 @@
void HopperMultipleMatmulScheduler::parallelizeBlocks(
const std::vector<TensorView*>& tvs) const {
for (TensorView* tv : tvs) {
-    switch (params_->cta_order) {
-      // TODO: Should we instead check the roles of these dimensions to take the
-      // outermost two M or N axes?
-      case MatmulParams::TileRasterizationOrder::RowMajor:
-        tv->axis(num_device_and_batch_dims_)->parallelize(ParallelType::BIDx);
-        tv->axis(num_device_and_batch_dims_ + 1)
-            ->parallelize(ParallelType::BIDy);
+    switch (params_->tiling_strategy) {
+      case MatmulParams::TilingStrategy::OneTilePerCTA:
+        // Data-parallel kernels are parallelized BIDx BIDy
+        switch (params_->cta_order) {
+          // TODO: Should we instead check the roles of these dimensions to take
+          // the outermost two M or N axes?
+          case MatmulParams::TileRasterizationOrder::RowMajor:
+            tv->axis(num_device_and_batch_dims_)
+                ->parallelize(ParallelType::BIDx);
+            tv->axis(num_device_and_batch_dims_ + 1)
+                ->parallelize(ParallelType::BIDy);
+            break;
+          case MatmulParams::TileRasterizationOrder::ColumnMajor:
+            tv->axis(num_device_and_batch_dims_)
+                ->parallelize(ParallelType::BIDy);
+            tv->axis(num_device_and_batch_dims_ + 1)
+                ->parallelize(ParallelType::BIDx);
+            break;
+          default:
+            NVF_THROW(
+                "Invalid TileRasterizationOrder passed to Matmul scheduler");
+        }
        break;
-      case MatmulParams::TileRasterizationOrder::ColumnMajor:
-        tv->axis(num_device_and_batch_dims_)->parallelize(ParallelType::BIDy);
+      case MatmulParams::TilingStrategy::DistributeTilesAcrossSMs:
+      case MatmulParams::TilingStrategy::DistributeStagesAcrossSMs:
+        // For persistent kernels, we just parallelize the SM dimension
tv->axis(num_device_and_batch_dims_ + 1)
->parallelize(ParallelType::BIDx);
break;
default:
NVF_THROW("Invalid TileRasterizationOrder passed to Matmul scheduler");
}
}
}
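
The net effect on launch geometry: a OneTilePerCTA kernel still launches one CTA per output tile over BIDx/BIDy (swapped for column-major rasterization), while both persistent strategies launch one CTA per SM on BIDx alone. A short editor's sketch under the same assumed sizes as the previous illustration:

// Grid shapes implied by parallelizeBlocks (sizes assumed, not from the PR).
#include <cstdio>

int main() {
  const int tiles_m = 16, tiles_n = 16, num_sms = 132;
  // OneTilePerCTA, RowMajor: BIDx over M tiles, BIDy over N tiles
  printf("data-parallel grid: (%d, %d, 1)\n", tiles_m, tiles_n);
  // OneTilePerCTA, ColumnMajor: the two block dimensions swap
  printf("column-major grid:  (%d, %d, 1)\n", tiles_n, tiles_m);
  // DistributeTilesAcrossSMs / DistributeStagesAcrossSMs: one CTA per SM
  printf("persistent grid:    (%d, 1, 1)\n", num_sms);
  return 0;
}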
@@ -452,7 +514,7 @@ void HopperMultipleMatmulScheduler::scheduleMmaResults() {
splitk_sums_.push_back(splitk_sum);
}

-    transformLikeMmaOutput(mma_result, /*is_mma_result=*/true);
+    transformLikeMmaOutputWithK(mma_result);
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
mma_result->getLoopDomain());
mma_result->setAllocationDomain(s.as<IterDomain*>(), true);
@@ -487,7 +549,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
// op.
blockTileTensors({d});
parallelizeBlocks({d});
-      transformLikeMmaOutput(d, /*is_mma_result=*/false);
+      transformLikeMmaOutputWithoutK(d);

auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
d->getLoopDomain());
@@ -518,8 +580,8 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
// tile is a multiple of the macro size because stmatrix stores results from
// wgmma to shared memory. For maximum inlining and to reduce shared memory
// usage, the tma tile is mma_macro size.
-    const int64_t tma_m = getM(params_->mma_macro);
-    const int64_t tma_n = getN(params_->mma_macro);
+    const int64_t tma_m = params_->tile_sizes.warp_tile.m;
+    const int64_t tma_n = params_->tile_sizes.warp_tile.n;

fusion_->manage("st_matrix_m_tile", stmatrix_tile_m);
fusion_->manage("st_matrix_n_tile", stmatrix_tile_n);
@@ -567,7 +629,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
blockTileTensors(tvs_to_schedule);
parallelizeBlocks(tvs_to_schedule);
for (auto tv : tvs_to_schedule) {
-      transformLikeMmaOutput(tv, /*is_mma_result=*/false);
+      transformLikeMmaOutputWithoutK(tv);
}

// Should not propagate if the dc is a mma output as the mma output has
@@ -618,7 +680,7 @@ void HopperMultipleMatmulScheduler::scheduleSplitKSum() {
for (TensorView* splitk_sum : splitk_sums_) {
// Always use serial grid reduction for split-K sum
splitk_sum->definition()->as<ReductionOp>()->requestSerialGridReduction();
-    transformLikeMmaOutput(splitk_sum, /*is_mma_result=*/false);
+    transformLikeMmaOutputWithoutK(splitk_sum);
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
splitk_sum->getLoopDomain());
splitk_sum->setLoopDomain(s.as<IterDomain*>());
7 changes: 6 additions & 1 deletion csrc/scheduler/hopper_multi_matmul.h
@@ -191,7 +191,12 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler {
// Schedule a block-tiled TensorView like mma output.
// Why? WGMMA has a unique output format. TensorViews after the mma-result in
// registers must respect this format for correctness.
-  void transformLikeMmaOutput(TensorView* tv, bool is_mma_result);
+  // This version is meant to be used on the mma_result, which has a Reduction
+  // K axis.
+  void transformLikeMmaOutputWithK(TensorView* tv);
+
+  // This is like the above method, but tv should not have any K dimension.
+  void transformLikeMmaOutputWithoutK(TensorView* tv);

private:
std::vector<ValGroup> canonical_dim_ordering_;
4 changes: 4 additions & 0 deletions csrc/scheduler/matmul_utils.cpp
@@ -308,6 +308,10 @@ bool fillDefaultHopperHeuristic(

mparams->tile_sizes = {cta_tile, warp_tile};

+  // Use warp specialization on hopper by default
[Review thread]
Reviewer (Collaborator): I thought using warp specialization by default was causing some test failures.
Author: Not anymore. I think that was before integrating the warp tile split.

+  mparams->circular_buffering_strategy =
+      MatmulParams::CircularBufferingStrategy::WarpSpecialized;

// stages and async mem copy
mparams->circular_buffer_options.smem_circular_buffer_stage = 8;
