
Fix scheduling of split-K with smem_epilogue on Hopper #4257


Merged

jacobhinkle merged 7 commits into main from jh/fix_splitk_smem_epilogue_hopper on Apr 17, 2025

Conversation

jacobhinkle
Collaborator

Introduces cacheBefore to match the cacheAfter utility; like cacheAfter, it propagates entries in graph_ corresponding to the new IDs in the cached tensors (see the sketch below). Also avoids re-scheduling tensors if they are split-K sum tensors.

There is currently a limitation for 32-bit outputs: we skip stmatrix, and our vectorized smem stores encounter 2-way bank conflicts. This is probably not significant for perf and can be fixed by scheduling that store properly in a follow-up PR.
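
For context, here is a minimal sketch of what such a cacheBefore wrapper can look like. It mirrors the cacheAfter excerpt quoted in the review section below; the signature, the class name MultipleMatmulScheduler, and the assumption that the freshly created IDs end up on the cache tensor are illustrative, not necessarily the exact code in this PR:

// Sketch only: wraps TensorView::cacheBefore and registers the cache's new
// IterDomains into the ValGroups of the original tensor's IDs, so that
// mappings built in graph_ before caching remain valid afterwards.
// (Which side holds the new IDs is an assumption mirrored from cacheAfter;
// orig is assumed to be a fusion output with no reduction domains.)
TensorView* MultipleMatmulScheduler::cacheBefore(
    TensorView* orig,
    LoadStoreOpType op_type) {
  const std::vector<IterDomain*> orig_logical = orig->getLogicalDomain();

  TensorView* c = orig->cacheBefore(op_type);

  const std::vector<IterDomain*> cache_logical = c->getLogicalDomain();
  NVF_ERROR(orig_logical.size() == cache_logical.size());
  for (size_t i : arange(orig_logical.size())) {
    // Place each new cache ID in the same ValGroup as its original.
    ValGroup vg = graph_->toGroup(orig_logical[i]);
    graph_->initializeVal(cache_logical[i], vg);
  }
  return c;
}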

Fixes #4159

@jacobhinkle jacobhinkle requested a review from rdspring1 April 16, 2025 13:11
@jacobhinkle jacobhinkle changed the title from "Jh/fix splitk smem epilogue hopper" to "Fix scheduling of split-K with smem_epilogue on Hopper" on Apr 16, 2025
@jacobhinkle
Collaborator Author

!test


github-actions bot commented Apr 16, 2025

Review updated until commit d302129

Description

  • Introduced cacheBefore to match cacheAfter utility.

  • Modified scheduling logic to use cacheBefore for epilogue tensors.

  • Added defaultHopperParams function for test configuration (a hedged sketch follows this list).

  • Added test case for split-K with smem epilogue on Hopper.
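
As a rough illustration of the new test helper, here is a hedged sketch. The only fields confirmed by this PR's test are use_smem_epilogue and splitk_factor; the macro choice and default values are placeholders that may not match the actual defaultHopperParams in tests/cpp/test_matmul.cpp:

// Illustrative defaults for a Hopper matmul test; concrete values below are
// placeholders, not necessarily those chosen in this PR.
MatmulParams defaultHopperParams() {
  MatmulParams mparams;
  mparams.mma_macro = MmaMacro::Hopper_64_128_16; // assumed macro choice
  mparams.use_smem_epilogue = false; // tests override as needed
  mparams.splitk_factor = 1; // no split-K unless a test overrides it
  return mparams;
}

A test then tweaks only what it needs, as HSS_NT_SplitKTMAStore does below by setting mparams.use_smem_epilogue = true and mparams.splitk_factor = 2.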


Changes walkthrough 📝

Relevant files

Enhancement

hopper_multi_matmul.cpp: Update scheduling to use cacheBefore
csrc/scheduler/hopper_multi_matmul.cpp
  • Removed cacheAfter method.
  • Modified scheduleEpilogue to use cacheBefore.
  • Updated logic to skip scheduling for split-K sum tensors.
  +7/-42

multi_matmul.cpp: Add cacheBefore and update cacheAfter
csrc/scheduler/multi_matmul.cpp
  • Added cacheBefore method.
  • Retained cacheAfter method with updates.
  +54/-0

Tests

test_matmul.cpp: Add test for split-K smem epilogue
tests/cpp/test_matmul.cpp
  • Added defaultHopperParams function.
  • Added test case HSS_NT_SplitKTMAStore for split-K with smem epilogue.
  +82/-0

Formatting

hopper_multi_matmul.h: Remove cacheAfter declaration
csrc/scheduler/hopper_multi_matmul.h
  • Removed cacheAfter declaration.
  +0/-8

multi_matmul.h: Add cacheBefore and cacheAfter declarations
csrc/scheduler/multi_matmul.h
  • Added cacheBefore and cacheAfter declarations.
  +14/-0

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Cache Functionality

The introduction of cacheBefore and the removal of cacheAfter in hopper_multi_matmul.cpp should be validated to ensure that the caching mechanism is correctly implemented and does not introduce performance regressions.

    // Manually schedule register cache and output TensorView
    for (Val* dv : fusion_->outputs()) {
      TensorView* d = dv->as<TensorView>();
      NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
      TensorView* dc = d->definition()->input(0)->as<TensorView>();
    
      // The chain of operations storing data to global memory:
      //   registers -> (stmatrix) -> smem -> (tma_store) -> gmem
      TensorView* d_smem = cacheBefore(d, LoadStoreOpType::Set);
    
      std::vector<TensorView*> tvs_to_schedule{d, d_smem};
      bool dc_is_mma_result =
          std::find(mma_results_.begin(), mma_results_.end(), dc) !=
          mma_results_.end();
      bool dc_is_splitk_sum = params_->splitk_factor > 1 &&
          std::find(splitk_sums_.begin(), splitk_sums_.end(), dc) !=
              splitk_sums_.end();
    
      if (!dc_is_mma_result && !dc_is_splitk_sum) {
    // Skip scheduling dc if it is an mma_result. This can happen if we are
    // … (snippet truncated by the review bot)

Code Duplication

The cacheAfter function is duplicated in both hopper_multi_matmul.cpp and multi_matmul.cpp. This duplication should be reviewed to ensure consistency and maintainability.

// Signature line reconstructed from context; it was truncated in the bot's
// snippet, and the class name is inferred from csrc/scheduler/multi_matmul.h.
TensorView* MultipleMatmulScheduler::cacheAfter(
    TensorView* orig,
    LoadStoreOpType op_type,
    CacheOp cache_op,
    bool propagate_allocation_domain) {
  const std::vector<IterDomain*> orig_alloc = orig->getMaybeAllocationDomain();

  TensorView* c =
      orig->cacheAfter(op_type, cache_op, propagate_allocation_domain);

  if (propagate_allocation_domain) {
    const std::vector<IterDomain*> cache_alloc = c->getMaybeAllocationDomain();
    NVF_ERROR(orig_alloc.size() == cache_alloc.size());
    for (size_t i : arange(orig_alloc.size())) {
      ValGroup vg = graph_->toGroup(orig_alloc[i]);
      graph_->initializeVal(cache_alloc[i], vg);
    }
  }

  const std::vector<IterDomain*> orig_logical =
      TensorDomain::noReductions(orig->getLogicalDomain());
  const std::vector<IterDomain*> cache_logical = c->getLogicalDomain();
  // in split-K we do rFactor which gives us a full = sum(partial)
  // where partial has root domain that matches the logical domain of the
  // original tensor. The logical domain contains Iteration transforms of the
  // Reduction axis in the original mma output.
  NVF_ERROR(orig_logical.size() == cache_logical.size());
  for (size_t i : arange(orig_logical.size())) {
    ValGroup vg = graph_->toGroup(orig_logical[i]);
    graph_->initializeVal(cache_logical[i], vg);
  }

  return c;
}
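
The rFactor comment in the excerpt above refers to the standard split-K decomposition. A minimal illustration of the pattern, with hypothetical axis positions and variable names (not code from this PR):

// mma_result: [iM, iN, rK]. Split K into splitk_factor outer slices.
mma_result->split(-1, splitk_factor, /*inner_split=*/false);
// mma_result: [iM, iN, rKo{splitk_factor}, rKi]
// rFactor the inner piece: `partial` reduces rKi within each slice, while
// mma_result is left summing over the slices, i.e. full = sum(partial).
// partial's root domain matches the original logical domain, and its
// logical domain carries iKo, an Iteration transform of the reduction axis.
TensorView* partial = mma_result->rFactor({-1});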

Bank Conflicts

The test HSS_NT_SplitKTMAStore mentions a known issue with 2-way bank conflicts for 32-bit outputs. This should be addressed or documented further to ensure that the performance impact is understood and mitigated.

    TEST_F(HopperMatmulTest, HSS_NT_SplitKTMAStore) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      constexpr int64_t M = 2048, N = 2048, K = 8192;
      const auto dtype = DataType::Half;
    
      auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // K, M
      auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype); // K, N
      fusion.addInput(tv0);
      fusion.addInput(tv1);
    
      auto tv2 = fusedMultiplySum(tv0, tv1, {0});
    
      // Reorder the accumulator as [M, N, K]
      // [K, M, N] -> [M, N, K]
      tv2->reorder({{-3, -1}});
      tv2->commitLeafToLogical();
    
      fusion.addOutput(tv2);
    
      auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
      auto t0 = at::randn({K, M, 1}, options);
      auto t1 = at::randn({K, 1, N}, options);
      auto out_ref =
          at::matmul(t0.squeeze().t().to(at::kFloat), t1.squeeze().to(at::kFloat));
    
      MatmulParams mparams = defaultHopperParams();
      mparams.use_smem_epilogue = true;
      mparams.splitk_factor = 2;
    
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      KernelExecutor ke;
      ke.compile(&fusion, {t0, t1});
      // TODO: Either enable stmatrix for 32-bit outputs or fix current 2-way bank
      // conflict by scheduling the vectorized store properly
      auto bank_conflicts = getBankConflictInfo(ke.compiledKernel()->kernel());
      EXPECT_EQ(bank_conflicts.size(), 1);
      for (const auto& [expr, conflict_ways] : bank_conflicts) {
        int64_t input_ways, output_ways;
        std::tie(input_ways, output_ways) = conflict_ways;
        EXPECT_EQ(input_ways, 0);
        EXPECT_EQ(output_ways, 2);
      }
      auto cg_outputs = ke.run({t0, t1});
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
    
      // Relax tolerance for larger sum due to large K
      NVF_CHECK(at::allclose(
          cg_outputs[0].as<at::Tensor>(), out_ref, 1e-6 * K, 1e-6 * K));
    }
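
To make the conflict_ways pair in the test concrete: shared memory has 32 four-byte banks, and an access is N-way conflicted when N threads accessing smem in the same phase touch N distinct words in one bank. The helper below is a standalone sketch of that counting, not nvFuser's getBankConflictInfo:

#include <algorithm>
#include <cstdint>
#include <map>
#include <set>
#include <vector>

// Count conflict ways for one simultaneous set of smem byte addresses:
// the maximum number of distinct 4-byte words hitting any single bank.
// 1 means conflict-free; 2 matches the output_ways the test expects.
int64_t conflictWays(const std::vector<int64_t>& byte_addrs) {
  std::map<int64_t, std::set<int64_t>> words_per_bank;
  for (int64_t addr : byte_addrs) {
    int64_t word = addr / 4; // 4-byte bank width
    words_per_bank[word % 32].insert(word); // 32 banks
  }
  int64_t ways = 1;
  for (const auto& [bank, words] : words_per_bank) {
    ways = std::max<int64_t>(ways, static_cast<int64_t>(words.size()));
  }
  return ways;
}

With 128-bit vectorized accesses the hardware issues a warp's access in phases of eight threads, so whether the 2-way conflict appears depends on how the store's addresses interleave within each phase; scheduling that store properly is the follow-up work the TODO mentions.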

This can be used to clarify the other tests too, but I will do that in another PR and run codediff etc.
@jacobhinkle
Collaborator Author

!test --diff

@jacobhinkle
Collaborator Author

Failures are unrelated.

@jacobhinkle jacobhinkle merged commit 9b9cd8f into main on Apr 17, 2025
52 of 56 checks passed
@jacobhinkle jacobhinkle deleted the jh/fix_splitk_smem_epilogue_hopper branch on April 17, 2025 18:22

Successfully merging this pull request may close these issues:

Error scheduling Hopper matmul with use_smem_epilogue and splitk (#4159)