diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp index b242369b5a2..a6972baa72d 100644 --- a/csrc/scheduler/hopper_multi_matmul.cpp +++ b/csrc/scheduler/hopper_multi_matmul.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -105,11 +106,6 @@ void HopperMultipleMatmulScheduler::validate() const { "Hopper matmul scheduler does not support scheduling persistent split-K kernels"); } - NVF_CHECK( - params_->tiling_strategy != - MatmulParams::TilingStrategy::DistributeTilesAcrossSMs, - "Hopper matmul scheduler TEMPORARILY does not support persistent scheduling of tiles yet"); - NVF_CHECK( params_->tiling_strategy != MatmulParams::TilingStrategy::DistributeStagesAcrossSMs, @@ -374,6 +370,21 @@ std::vector> HopperMultipleMatmulScheduler:: } } } + + if (params_->tiling_strategy == + MatmulParams::TilingStrategy::DistributeTilesAcrossSMs) { + // Persistent kernel scheduling + if (params_->cta_order == + MatmulParams::TileRasterizationOrder::ColumnMajor) { + tv->reorder( + {{num_device_and_batch_dims_, num_device_and_batch_dims_ + 1}}); + } + tv->merge(num_device_and_batch_dims_, num_device_and_batch_dims_ + 1); + + const int64_t num_sms = + at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + tv->split(num_device_and_batch_dims_, num_sms); + } } return all_merged_roles; } diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index e2739bc50e6..a63d6feaebe 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -4262,9 +4262,6 @@ class MLPBenchmarkTest // warp specialization requires Hopper+ NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(9, 0, 10, 0); } - if (test_params.persistent_kernel) { - GTEST_SKIP() << "persistent kernel tests are currently disabled"; - } } }; @@ -4508,6 +4505,9 @@ INSTANTIATE_TEST_SUITE_P( MLPBenchmarkTestParams{ .warp_specialization = true, .persistent_kernel = false}, + MLPBenchmarkTestParams{ + .warp_specialization = false, + .persistent_kernel = true}, MLPBenchmarkTestParams{ .warp_specialization = true, .persistent_kernel = true}),