
Implement persistent matmul scheduling #3812

Open

jacobhinkle wants to merge 39 commits into main from jh/persistent_kernel_impl
Conversation

@jacobhinkle (Collaborator) commented Feb 3, 2025:

Stacked on #3642

This is a followup to #3792 that implements persistent scheduling.
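For context, "persistent" scheduling here means sizing the grid to the SM count and having each CTA loop over multiple output tiles, rather than launching one CTA per tile. A minimal CUDA-style sketch of the idea (illustrative only; the names and structure are not nvFuser's generated code):

    // Each CTA strides over the flattened output-tile space. The grid is
    // sized to roughly one CTA per SM, so CTAs persist across tiles.
    __global__ void persistentMatmulSketch(int64_t tiles_m, int64_t tiles_n) {
      int64_t num_tiles = tiles_m * tiles_n;
      for (int64_t tile = blockIdx.x; tile < num_tiles; tile += gridDim.x) {
        int64_t tile_m = tile / tiles_n; // row tile index
        int64_t tile_n = tile % tiles_n; // column tile index
        // ... load A/B tiles for (tile_m, tile_n), MMA over K, epilogue ...
      }
    }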

There is a current limitation that affects both persistent scheduling and "grid swizzling": if MatmulOp or LinearOp is present in the fusion, we hit inlining errors. This is because in that case we have a non-trivial AxisMapping on the MmaOp, and the missing input dimensions are not tracked through the scheduling transforms (merges and splits) required for either grid swizzling or persistent scheduling. Because of this, I introduced three new parametrized tests matching the original MLPBenchmarkTests but with a _BroadcastInputs suffix; these use fusedMultiplySum instead of linear (see the sketch below). The persistent variants of the non-BroadcastInputs tests are skipped until we fix the inlining issue.
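A hedged sketch of the _BroadcastInputs pattern, following the style of existing matmul tests (makeContigConcreteTensor and fusedMultiplySum are the usual nvFuser test and ops helpers; shapes and dtypes are illustrative):

    Fusion fusion;
    FusionGuard fg(&fusion);
    // Inputs are pre-broadcast to aligned 3D shapes so the fusion contains an
    // MmaOp (via fusedMultiplySum) rather than MatmulOp/LinearOp, avoiding the
    // non-trivial AxisMapping that currently breaks inlining.
    auto tv0 = makeContigConcreteTensor({-1, 1, -1}, DataType::BFloat16); // [M, 1, K]
    auto tv1 = makeContigConcreteTensor({1, -1, -1}, DataType::BFloat16); // [1, N, K]
    fusion.addInput(tv0);
    fusion.addInput(tv1);
    auto tv2 = fusedMultiplySum(tv0, tv1, {2}); // multiply, then sum over K -> [M, N]
    fusion.addOutput(tv2);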

I currently observe a correctness issue in the MLPBenchmarkTest.FwdEpilogueFusion_BroadcastInputs test regardless of parametrization, meaning we get incorrect results even for data-parallel scheduling. I confirmed this test also fails on main, so I currently skip it with a warning message.

jacobhinkle and others added 29 commits, beginning December 23, 2024 20:54. Commit message excerpts:

    I think this covers the motivation for #3616

    There is still one case that fails, which we should fix. I'll create an issue for it.
@jacobhinkle requested a review from rdspring1, February 3, 2025 16:04

@jacobhinkle (Collaborator, Author) commented:
!test

github-actions bot commented Feb 3, 2025:

Review updated until commit 0cbb3e6

Description

  • Implement persistent scheduling for Hopper matmul kernels

  • Add support for warp specialization on Hopper by default

  • Parametrize MLP Benchmark tests to include persistent and warp specialization configurations (a sketch of the parameter struct follows this list)

  • Add new tests for broadcast inputs in MLP benchmarks
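The test parameter struct implied by the changes is a hypothetical reconstruction here: the two field names are grounded in the INSTANTIATE snippet quoted later in this conversation, but the defaults are assumptions.

    // Hypothetical reconstruction from tests/cpp/test_matmul.cpp; the fields
    // appear in the review snippets below, the default values are assumed.
    struct MLPBenchmarkTestParams {
      bool warp_specialization = false;
      bool persistent_kernel = false;
    };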


Changes walkthrough 📝

Relevant files (Enhancement):

csrc/scheduler/hopper_multi_matmul.cpp: Add persistent scheduling support (+41/-16)

  • Include matmul_heuristic.h
  • Remove temporary check for persistent scheduling
  • Add persistent kernel scheduling logic
  • Update block parallelization based on tiling strategy (sketched below)

csrc/scheduler/matmul_utils.cpp: Set warp specialization default (+4/-0)

  • Set warp specialization as default on Hopper

tests/cpp/test_matmul.cpp: Update MLPBenchmarkTest for persistent kernels (+235/-80)

  • Parametrize MLPBenchmarkTest to include persistent and warp specialization
  • Add tests for broadcast inputs
  • Skip persistent kernel tests for unsupported operations
  • Adjust parameters for smem and register constraints
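A hedged sketch of how the tiling strategy might drive block parallelization, using nvFuser scheduling primitives; the axis positions, tv, and num_sms are illustrative stand-ins, not the actual hopper_multi_matmul.cpp code:

    // Data-parallel: one CTA per output tile. Persistent: merge the two tile
    // axes, outer-split by SM count, bind the outer axis to BIDx, and let
    // each CTA serially iterate over its share of tiles.
    if (params_->tiling_strategy == MatmulParams::TilingStrategy::OneTilePerCTA) {
      tv->axis(0)->parallelize(ParallelType::BIDy); // M tile axis
      tv->axis(1)->parallelize(ParallelType::BIDx); // N tile axis
    } else {
      tv->merge(0, 1);                              // [tiles_m * tiles_n]
      tv->split(0, num_sms, /*inner_split=*/false); // [num_sms, tiles per CTA]
      tv->axis(0)->parallelize(ParallelType::BIDx);
    }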

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Possible Issue

The PR introduces persistent kernel scheduling, but there are known inlining errors when MatmulOp or LinearOp are present in the fusion. This should be addressed to ensure correctness.

    if (params_->tiling_strategy != MatmulParams::TilingStrategy::OneTilePerCTA) {
      NVF_CHECK(
          params_->splitk_factor == 1,
          "Hopper matmul scheduler does not support scheduling persistent split-K kernels");
    }

    NVF_CHECK(
        params_->tiling_strategy !=
            MatmulParams::TilingStrategy::DistributeStagesAcrossSMs,
        "Hopper matmul scheduler does not support distributing stages across SMs a la stream-K");

Failing Tests

The FwdEpilogueFusion_BroadcastInputs test is currently failing. This should be investigated and resolved to ensure the correctness of the persistent kernel implementation.

    TEST_P(MLPBenchmarkTest, FwdEpilogueFusion_BroadcastInputs) {
      GTEST_SKIP() << "THIS TEST IS CURRENTLY FAILING" << std::endl;

Test Skips

Multiple tests are skipped due to unsupported features or known issues. These should be addressed or documented to ensure comprehensive testing.

    auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA);
    auto a_ref = at::randn({M, K}, options);

@rdspring1 (Collaborator) left a comment:

LGTM.

Looks like some overlap with #3642. Do you plan to merge it first?

    MLPBenchmarkTestParams{
        .warp_specialization = true,
        .persistent_kernel = true}),
    [](const testing::TestParamInfo<MLPBenchmarkTestParams>& info) {
      std::stringstream ss;
      ss << (info.param.persistent_kernel ? "persistent" : "data_parallel");
      ss << (info.param.persistent_kernel ? "persistent" : "dataparallel");
@rdspring1 (Collaborator) commented:

Is this rename from a bad merge?

@jacobhinkle (Collaborator, Author) replied:

I just hadn't merged #3642 in a while, I think.

    @@ -308,6 +308,10 @@ bool fillDefaultHopperHeuristic(

    mparams->tile_sizes = {cta_tile, warp_tile};

    // Use warp specialization on hopper by default
@rdspring1 (Collaborator) commented:

I thought using warp specialization by default was causing some test failures.

@jacobhinkle (Collaborator, Author) replied:

Not anymore. I think that was before integrating the warp tile split.
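For reference, a hedged sketch of what such a default could look like inside fillDefaultHopperHeuristic; the MatmulParams field and enum names here are assumptions, not copied from the diff:

    // Hypothetical: prefer warp-specialized circular buffering on Hopper
    // (field and enum names are assumed, not confirmed by this PR's diff).
    mparams->circular_buffering_strategy =
        MatmulParams::CircularBufferingStrategy::WarpSpecialized;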

@jacobhinkle force-pushed the jh/persistent_kernel_impl branch from 52d2bca to 07c93c6, February 6, 2025 14:03

@jacobhinkle (Collaborator, Author) commented:

!test

@jacobhinkle (Collaborator, Author) commented:

Grrr. Failures on Ampere. Will fix before merging.

@jacobhinkle (Collaborator, Author) commented:

!test

@jacobhinkle requested a review from rdspring1, February 7, 2025 16:19