Skip to content

Commit

Permalink
Merge in from #3642. Add persistent change
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobhinkle committed Jan 31, 2025
1 parent ffa276e commit 86d75de
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
21 changes: 16 additions & 5 deletions csrc/scheduler/hopper_multi_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <scheduler/debug_utils.h>
#include <scheduler/hopper_multi_matmul.h>
#include <scheduler/matmul.h>
#include <scheduler/matmul_heuristic.h>
#include <scheduler/matmul_utils.h>
#include <scheduler/mma_utils.h>
#include <scheduler/tools/abstract_tensor.h>
Expand Down Expand Up @@ -105,11 +106,6 @@ void HopperMultipleMatmulScheduler::validate() const {
"Hopper matmul scheduler does not support scheduling persistent split-K kernels");
}

NVF_CHECK(
params_->tiling_strategy !=
MatmulParams::TilingStrategy::DistributeTilesAcrossSMs,
"Hopper matmul scheduler TEMPORARILY does not support persistent scheduling of tiles yet");

NVF_CHECK(
params_->tiling_strategy !=
MatmulParams::TilingStrategy::DistributeStagesAcrossSMs,
Expand Down Expand Up @@ -374,6 +370,21 @@ std::vector<std::vector<MatmulDimRole>> HopperMultipleMatmulScheduler::
}
}
}

if (params_->tiling_strategy ==
MatmulParams::TilingStrategy::DistributeTilesAcrossSMs) {
// Persistent kernel scheduling
if (params_->cta_order ==
MatmulParams::TileRasterizationOrder::ColumnMajor) {
tv->reorder(
{{num_device_and_batch_dims_, num_device_and_batch_dims_ + 1}});
}
tv->merge(num_device_and_batch_dims_, num_device_and_batch_dims_ + 1);

const int64_t num_sms =
at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
tv->split(num_device_and_batch_dims_, num_sms);
}
}
return all_merged_roles;
}
Expand Down
6 changes: 3 additions & 3 deletions tests/cpp/test_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4262,9 +4262,6 @@ class MLPBenchmarkTest
// warp specialization requires Hopper+
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(9, 0, 10, 0);
}
if (test_params.persistent_kernel) {
GTEST_SKIP() << "persistent kernel tests are currently disabled";
}
}
};

Expand Down Expand Up @@ -4508,6 +4505,9 @@ INSTANTIATE_TEST_SUITE_P(
MLPBenchmarkTestParams{
.warp_specialization = true,
.persistent_kernel = false},
MLPBenchmarkTestParams{
.warp_specialization = false,
.persistent_kernel = true},
MLPBenchmarkTestParams{
.warp_specialization = true,
.persistent_kernel = true}),
Expand Down

0 comments on commit 86d75de

Please sign in to comment.