
Issue multiple wgmma operations when CTA k dim is a multiple of 16 #3616

Merged: 9 commits, Feb 3, 2025

Conversation

@rdspring1 (Collaborator) commented Dec 19, 2024

This PR fixes the incorrect-results issue that occurs when the k dimension of the CTA tile is a multiple of getK(mma_macro).

Why?

  • In scheduleMmaResults, we need to split the k reduction by getK(mma_macro). A serial reduction then adds the wgmma results along the k dimension (see the condensed sketch after the details below).

Details

  • Modified the transformLikeMmaOutput function so that it is no longer used in scheduleMmaResults.
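
For reference, here is a condensed, illustrative sketch of the k split that scheduleMmaResults now performs (distilled from the scheduler hunk quoted further down this page; mma_result, params_, and the loop-domain comments follow that code, and the example extents are assumptions):

// Loop domain of mma_result before the split: [..., Mo, No, Mi, Ni, Ki],
// where Ki is the CTA-tile k extent (a multiple of getK(mma_macro)).
if (params_->tile_sizes.warp_tile.k > getK(params_->mma_macro)) {
  // Split Ki into [Kio, Kii] with Kii == getK(mma_macro). Each Kio iteration
  // issues one wgmma instruction, and the serial reduction over Kio adds the
  // wgmma partial results along the k dimension.
  // Example: cta_tile.k = 32 with MmaMacro::Hopper_64_256_16 gives Kio = 2,
  // i.e. two wgmma operations whose results are summed.
  mma_result->split(-1, getK(params_->mma_macro));
}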

@rdspring1 (Collaborator Author)

!test

Comment on lines 4036 to 4037
// NOTE Certain combinations of cta k dimension and circular buffer
// prefetching can produce incorrect results.
Collaborator

👀


- mparams.supported_vec_size = {8, 8, 4};
+ mparams.supported_vec_size = {8, 8, 8};
Collaborator

Good catch, although I think it currently has no meaning until we start handling epilogue inputs with supported vec size.

@rdspring1 (Collaborator Author)

!test

@rdspring1 force-pushed the hopper_matmul_cta_k_fix branch 4 times, most recently from d366352 to 0c784e0 on December 20, 2024 at 23:56
@rdspring1 (Collaborator Author)

!test

jacobhinkle added a commit that referenced this pull request Jan 2, 2025
I think this covers the motivation for #3616
@rdspring1 force-pushed the hopper_matmul_cta_k_fix branch from e3826e2 to 603d5c9 on January 29, 2025 at 19:31

github-actions bot commented Jan 29, 2025

PR Reviewer Guide 🔍

(Review updated until commit 74f2056)

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 No relevant tests
⚡ Recommended focus areas for review

Incorrect Results

The transformLikeMmaOutput function has been modified so that it is no longer used in scheduleMmaResults. This change may cause incorrect results when the k dimension of the CTA tile is a multiple of getK(mma_macro). The reviewer should verify that the new implementation produces correct results.

void HopperMultipleMatmulScheduler::transformLikeMmaOutput(TensorView* tv) {
  NVF_ERROR(
      tv->domain()->loop().size() >= 4,
      "transformLikeMmaOutput requires at least four iterDomains but ",
      tv->toString(),
      " only has ",
      tv->domain()->loop().size(),
      ".");

  // Original: [..., Mo, No, Mi, Ni]
  tv->split(-2, getM(params_->mma_macro));
  tv->split(-1, getN(params_->mma_macro));
  // After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
  tv->reorder({{-3, -2}});
  // After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
  tv->merge(-4);
  // After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
  tv->axis(-3)->parallelize(ParallelType::TIDy);
  // After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
}
WAR Hazard

The scheduleMmaResults function has been modified to split the k dimension of the warp tile only if it is larger than the k dimension of the mma macro. This change may introduce a WAR hazard between aliased shared memory buffers. The reviewer should verify that the new implementation does not introduce any WAR hazards.

// Original: [..., Mo, No, Mi, Ni, Ki]
mma_result->split(-3, getM(params_->mma_macro));
mma_result->split(-2, getN(params_->mma_macro));

// Split the k dimension of the warp tile only if it is larger than the k
// dimension of the mma macro. Inlining can land at an incorrect position for
// circular buffering if a reduction iterDomain has extent 1.
if (params_->tile_sizes.warp_tile.k > getK(params_->mma_macro)) {
  mma_result->split(-1, getK(params_->mma_macro));
  // After Split: [..., Mo, No, Mio, Mii, Nio, Nii, Kio, Kii]
  mma_result->reorder({{-5, -4}});
  // After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii, Kio, Kii]
  mma_result->reorder({{-2, -4}});
  // After Reorder: [..., Mo, No, Mio, Nio, Kio, Mii, Nii, Kii]
  mma_result->merge(-6);
  // After Merge: [..., Mo, No, Mio * Nio, Kio, Mii, Nii, Kii]
  mma_result->axis(-5)->parallelize(ParallelType::TIDy);
  // After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Kio, Mii, Nii, Kii]
} else {
  // After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
  mma_result->reorder({{-4, -3}});
  // After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
  mma_result->merge(-5);
  // After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
  mma_result->axis(-4)->parallelize(ParallelType::TIDy);
  // After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
}
Incorrect Placement of WGMMA Syncs

The test cases in this file have been modified to use a different tile size and mma macro. However, the placement of wgmma syncs may be incorrect, leading to incorrect results. The reviewer should verify that the wgmma syncs are correctly placed.

  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
  auto a_ref = at::randn({K, M, 1}, options);
  auto b_ref = at::randn({K, 1, N}, options);
  auto out_ref = at::matmul(a_ref.squeeze().t(), b_ref.squeeze()).to(at::kHalf);

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(128, 256, 32);
  gemm_tile.warp_tile = GemmTile(64, 256, 32);

  MatmulParams mparams;
  mparams.supported_vec_size = {8, 8, 8};
  mparams.mma_macro = MmaMacro::Hopper_64_256_16;
  mparams.tile_sizes = gemm_tile;
  mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor;
  mparams.async_gmem_load_operands = true;
  mparams.circular_buffer_options.circular_buffer_smem_write = true;
  mparams.circular_buffer_options.circular_buffer_smem_read = false;
  mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
  mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
  mparams.splitk_factor = 1;
  mparams.use_smem_epilogue = true;
  mparams.cluster_dims = {2, 1, 1};
  mparams.promote_prologue_smem_reuse = true;

  SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
      ->schedule(&fusion, &mparams);

  std::vector<c10::IValue> inputs = {a_ref, b_ref};

  KernelExecutor ke;
  ke.compile(&fusion, inputs);
  EXPECT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
  auto cg_outputs = ke.run(inputs);
  ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
      ke.compiledKernel()->kernel()));

  // Relax tolerance for larger sum due to large K
  EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
}

TEST_F(HopperMatmulTest, HSH_TN_UseScheduler) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  constexpr int64_t M = 2048, N = 2048, K = 8192;
  const auto dtype = DataType::Half;

  auto tv0 = makeContigConcreteTensor({-1, 1, -1}, dtype); // M, K
  auto tv1 = makeContigConcreteTensor({1, -1, -1}, dtype); // N, K
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  auto tv2 = fusedMultiplySum(tv0, tv1, {-1});

  auto tv3 = castOp(DataType::Half, tv2);
  fusion.addOutput(tv3);

  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);

@rdspring1 marked this pull request as ready for review on January 29, 2025 at 19:32
@jacobhinkle (Collaborator) left a comment

LGTM. I didn't update the cta tile K dim in the tests in #3642. It will be good to merge this and then verify we don't have any regressions in #3642 next.

@rdspring1 force-pushed the hopper_matmul_cta_k_fix branch from e68f85c to 74f2056 on January 31, 2025 at 01:11
@rdspring1 (Collaborator Author)

!test

@rdspring1 (Collaborator Author)

@jacobhinkle I ran into issues with tests HopperMatmulTest.MLPBenchmarkFwdGEMM and LinearNodeTranslationTest.AutomaticSchedulerLinearNode/2dA_2dB_1dBias.

For HopperMatmulTest.MLPBenchmarkFwdGEMM, the issue was that there were not enough k tiles to fill the circular buffer pipeline. I used the best configuration from the matmul sprint to fix this.

LinearNodeTranslationTest.AutomaticSchedulerLinearNode/2dA_2dB_1dBias is more troublesome: inlineMost places the computeAt position at the wrong place for circular buffering.

Why? The default hopper matmul heuristics use k=16. When warp_tile.k == mma_macro.k, a size-1 reduction iterDomain is created, so inlineMost can be pushed further to the right.

Fixes:

  • Bump k from 16 to 64 in the default heuristic's tile sizes
  • Add a separate path for warp_tile.k == mma_macro.k that does not split by mma_macro.k (sketched below)
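
A minimal sketch of the second fix, mirroring the scheduleMmaResults hunk quoted later on this page (the loop-domain comment and extents are illustrative):

// With the default heuristic, warp_tile.k == getK(mma_macro) == 16. Splitting
// Ki by getK(mma_macro) would then leave a reduction iterDomain of extent 1,
// e.g. [..., Mo, No, ..., Kio{1}, Kii{16}], which lets inlineMost push the
// computeAt position further right than circular buffering expects.
if (params_->tile_sizes.warp_tile.k > getK(params_->mma_macro)) {
  mma_result->split(-1, getK(params_->mma_macro)); // [..., Kio, Kii]
} else {
  // warp_tile.k == mma_macro.k: keep Ki as a single reduction iterDomain and
  // skip the split entirely.
}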

@jacobhinkle (Collaborator)

Why? The default hopper matmul heuristics use k=16. When warp_tile.k == mma_macro.k, a size-1 reduction iterDomain is created, so inlineMost can be pushed further to the right.

Maybe we could handle this explicitly in the setUpInlining() phase of the scheduler?

@rdspring1 (Collaborator Author)

I tried modifying setUpInlining this afternoon. inlineMost behaves differently from inlineSelectedAt, so inlineSelectedAt won't place the computeAt at the CTA tile position.

@jacobhinkle (Collaborator)

Why? The default hopper matmul heuristics use k=16. When warp_tile.k == mma_macro.k, a size-1 reduction iterDomain is created, so inlineMost can be pushed further to the right.

I am wondering why I don't see this issue on #3642. I do the K split also: https://github.com/NVIDIA/Fuser/pull/3642/files#diff-e7aea6139145124edbbea47b4cb1541d90715e3be70e54051e42eb14eec7588fR46.


github-actions bot commented Jan 31, 2025

Review updated until commit 2e703f8

Description

  • Updated transformLikeMmaOutput so that it no longer takes the is_mma_result parameter.

  • Updated scheduleMmaResults to split k dimension of warp tile conditionally.

  • Increased k dimension in test cases to 32 or 64 for better performance.

  • Adjusted warp_tile in fillDefaultHopperHeuristic to issue multiple wgmma instructions per warp group (see the sketch below).
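
As a rough illustration of the last bullet: the goal is a default warp tile whose k extent is several times getK(mma_macro), so each main-loop iteration issues multiple wgmma instructions. The snippet is a sketch only, reusing the GemmTile/getM/getN/getK helpers shown elsewhere on this page; the exact tile sizes chosen by fillDefaultHopperHeuristic are not quoted here, and the values below are assumptions based on the discussion.

// Sketch (assumed values): with a *_16 macro, a warp-tile k of 64 yields
// 64 / getK(macro) == 4 wgmma instructions per main-loop iteration, rather
// than exactly one when warp_tile.k == getK(macro).
MmaMacro macro = MmaMacro::Hopper_64_256_16;
MatMulTileOptions gemm_tile;
gemm_tile.cta_tile = GemmTile(2 * getM(macro), getN(macro), 64);
gemm_tile.warp_tile = GemmTile(getM(macro), getN(macro), 64);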


Changes walkthrough 📝

Relevant files

Formatting
  allocation.cpp (csrc/device_lower/pass/allocation.cpp): Fix typo in comment
    • Corrected a typo in a comment.
    +1/-1
  insert_syncs.cpp (csrc/device_lower/pass/insert_syncs.cpp): Fix typos in comments
    • Corrected typos in comments.
    +2/-2

Bug fix
  hopper_multi_matmul.cpp (csrc/scheduler/hopper_multi_matmul.cpp): Update transformLikeMmaOutput and scheduleMmaResults
    • Updated transformLikeMmaOutput to remove the is_mma_result parameter.
    • Modified scheduleMmaResults to conditionally split the k dimension of the warp tile.
    • Updated scheduleEpilogue and scheduleSplitKSum to use the modified transformLikeMmaOutput.
    +44/-17
  hopper_multi_matmul.h (csrc/scheduler/hopper_multi_matmul.h): Update transformLikeMmaOutput signature
    • Updated the transformLikeMmaOutput function signature to remove the is_mma_result parameter.
    +1/-1

Enhancement
  matmul_utils.cpp (csrc/scheduler/matmul_utils.cpp): Increase k dimension in warp_tile
    • Increased the k dimension in warp_tile to issue multiple wgmma instructions per warp group.
    +4/-3

Tests
  test_matmul.cpp (tests/cpp/test_matmul.cpp): Update test cases with increased k dimension
    • Increased the k dimension in test cases to 32 or 64.
    • Disabled failing tests due to incorrect results.
    +33/-30
  test_matmul_scheduler.cpp (tests/cpp/test_matmul_scheduler.cpp): Update test cases with increased k dimension
    • Updated gemm_tile configurations in test cases.
    • Increased the k dimension in test cases.
    +4/-4

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Performance Concern

The changes to transformLikeMmaOutput and scheduleMmaResults may impact performance. Ensure that the new logic does not introduce unnecessary overhead or reduce efficiency.

    #include <utils.h>
    #include <val_graph.h>
    #include <val_graph_visitor.h>
    
    // NOTE: included to avoid compilation error caused by missing destructor in
    // 'SchedulerRuntimeInfo'
    #include <runtime/executor_utils.h>
    #include "mma_type.h"
    
    namespace nvfuser {
    
    void HopperMultipleMatmulScheduler::transformLikeMmaOutput(TensorView* tv) {
      NVF_ERROR(
          tv->domain()->loop().size() >= 4,
          "transformLikeMmaOutput requires at least four iterDomains but ",
          tv->toString(),
          " only has ",
          tv->domain()->loop().size(),
          ".");
    
      // Original: [..., Mo, No, Mi, Ni]
      tv->split(-2, getM(params_->mma_macro));
      tv->split(-1, getN(params_->mma_macro));
      // After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
      tv->reorder({{-3, -2}});
      // After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
      tv->merge(-4);
      // After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
      tv->axis(-3)->parallelize(ParallelType::TIDy);
      // After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
    }
    
    MatmulDimRole HopperMultipleMatmulScheduler::findMatmulDimRole(IterDomain* id) {
      ValGroup vg = graph_->toGroup(id);
      auto it = id_roles_.find(vg);
      NVF_ERROR(it != id_roles_.end());
      return it->second;
    }
    
    void HopperMultipleMatmulScheduler::validate() const {
      const auto device_prop = at::cuda::getCurrentDeviceProperties();
      const int cc = device_prop->major * 10 + device_prop->minor;
      NVF_ERROR(
          cc >= 90 && cc < 100, "This matmul scheduler is restricted to Hopper.");
    
      if (params_->tiling_strategy != MatmulParams::TilingStrategy::OneTilePerCTA) {
        NVF_CHECK(
            params_->splitk_factor == 1,
            "Hopper matmul scheduler does not support scheduling persistent split-K kernels");
      }
    
      NVF_CHECK(
          params_->tiling_strategy !=
              MatmulParams::TilingStrategy::DistributeTilesAcrossSMs,
          "Hopper matmul scheduler TEMPORARILY does not support persistent scheduling of tiles yet");
    
      NVF_CHECK(
          params_->tiling_strategy !=
              MatmulParams::TilingStrategy::DistributeStagesAcrossSMs,
          "Hopper matmul scheduler does not support distributing stages across SMs a la stream-K");
    
      NVF_CHECK(
          params_->buffering_loop_level ==
              MatmulParams::BufferingLoopLevel::CTATiles,
          "Hopper matmul scheduler only supports cooperatively buffering at the CTA level (no ping-pong)");
    }
    
    void HopperMultipleMatmulScheduler::run() {
      // Clears memory spaces on intermediate tensors, calls
      // cache{After,Before,Fork} on inputs and outputs
      cacheInputsAndOutputs();
    
      // Finds matmul patterns and translates them to MmaOps, then finds tensor
      // and dimension roles for all tensors in the fusion
      findPatterns();
      translatePatterns();
      findRoles();
    
      // Defines acw_smem/bcw_smem and acr/bcr by possibly calling cacheAfter.
      // This also collects mma_results_
      defineOperandCaches();
    
      inspectPrologues();
    
      setCGADims();
    
      scheduleOperands();
    
      // schedule mma instruction output (mma_result)
      scheduleMmaResults();
    
      // schedule epilogue
      scheduleEpilogue();
    
      // schedule splitk_sum
      scheduleSplitKSum();
    
      setUpInlining();
    
      // set up circular buffering. This must come after everything up to
      // mma_result is scheduled, since everything in the main loop will need to
      // be rotated
      setUpCircularBuffering();
    }
    
    void HopperMultipleMatmulScheduler::cacheInputsAndOutputs() {
      // Make sure we don't have global memory set on intermediate tensors from
      // fusion segmentation
      scheduler_utils::clearMemorySpace(fusion_);
    
      // Cache inputs
      scheduler_utils::cacheInputs(fusion_, /*unroll=*/true);
    
      // Cache and fork outputs
      scheduler_utils::cacheAndForkOutputs(fusion_, /*unroll=*/true);
    }
    
    void HopperMultipleMatmulScheduler::defineOperandCaches() {
      cacheOperandsToSmem(as_, acw_smems_);
      cacheOperandsToSmem(bs_, bcw_smems_);
    
      // Now that we are finished possibly redefining the inputs to the MmaOps,
      // we can set the macro for those ops
      for (TensorView* mma_result : mma_results_) {
        MmaOp* mma = dynamic_cast<MmaOp*>(mma_result->definition());
        NVF_ERROR(mma != nullptr);
        mma->setMacro(params_->mma_macro);
      }
    }
    
    void HopperMultipleMatmulScheduler::cacheOperandsToSmem(
        const std::vector<TensorView*>& operands,
        std::vector<TensorView*>& smem_operands) {
      // Use cp.async.bulk (tma) as requested in scheduler params.
      smem_operands.resize(operands.size(), nullptr);
      for (size_t i : c10::irange(operands.size())) {
        TensorView* operand = operands[i];
    
        NVF_ERROR(operand->uses().size() == 1);
        smem_operands[i] = ir_utils::consumerTvsOf(operand).at(0);
    
        LoadStoreOpType load_op = params_->async_gmem_load_operands
            ? LoadStoreOpType::CpAsyncBulkTensorTile
            : LoadStoreOpType::Set;
    
        smem_operands[i]->definition()->as<LoadStoreOp>()->setOpType(load_op);
        smem_operands[i]->setMemoryType(MemoryType::Shared);
      }
    }
    
    void HopperMultipleMatmulScheduler::swizzleBlockTiles(
        TensorView* tv,
        std::vector<MatmulDimRole>& outer_dim_roles) {
      if (params_->grid_swizzle_factor != 1) {
        // Find position of outer M and N dims in schedule_.tiled
        int64_t Mo_pos = -1, No_pos = -1;
        for (size_t i : c10::irange(outer_dim_roles.size())) {
          if (outer_dim_roles[i] == MatmulDimRole::M) {
            Mo_pos = (int64_t)i;
          } else if (outer_dim_roles[i] == MatmulDimRole::N) {
            No_pos = (int64_t)i;
          }
        }
    
        int factor = std::max(1, params_->grid_swizzle_factor); // must be >=1
        switch (params_->cta_order) {
          case MatmulParams::TileRasterizationOrder::RowMajor:
            // split   [I1, I2/factor, factor]
            // reorder [I1, factor, I2/factor]
            // merge   [I1*factor, I2/factor]
            // where I1 and I2 are the outer M and N dimensions, respectively
            if (No_pos >= 0) {
              tv->split(No_pos, factor);
              // If No_pos < Mo_pos, then the split above shifts Mo_pos by one
              if (No_pos < Mo_pos) {
                Mo_pos++;
              }
              tv->reorder({{No_pos, No_pos + 1}});
              if (Mo_pos >= 0) {
                tv->merge(Mo_pos, No_pos);
              } else {
                // M is missing, so we skip the merge above. In this case we
                // should update the dim roles to reflect the new split axis.
                outer_dim_roles.insert(
                    outer_dim_roles.begin() + No_pos, MatmulDimRole::N);
              }
            }
            break;
    
          case MatmulParams::TileRasterizationOrder::ColumnMajor:
            // split   [I1/factor, factor, I2]
            // reorder [I1/factor, I2, factor]
            // merge   [I1/factor, I2*factor]
            // where I1 and I2 are the outer M and N dimensions, respectively
            if (Mo_pos >= 0) {
              tv->split(Mo_pos, factor);
          // If Mo_pos < No_pos, then the split above shifts No_pos by one
              if (No_pos > Mo_pos) {
                No_pos++;
              }
              if (No_pos >= 0) {
                tv->reorder({{Mo_pos + 1, No_pos}});
                tv->merge(Mo_pos + 1, No_pos);
              } else {
                // N is missing, so we skip the merge above. In this case we
                // should update the dim roles to reflect the new split axis.
                outer_dim_roles.insert(
                    outer_dim_roles.begin() + Mo_pos, MatmulDimRole::M);
              }
            }
        }
      }
    }
    
    TensorView* HopperMultipleMatmulScheduler::cacheAfter(
        TensorView* orig,
        LoadStoreOpType op_type,
        CacheOp cache_op,
        bool propagate_allocation_domain) {
      const std::vector<IterDomain*> orig_alloc = orig->getMaybeAllocationDomain();
    
      TensorView* c =
          orig->cacheAfter(op_type, cache_op, propagate_allocation_domain);
    
      if (propagate_allocation_domain) {
        const std::vector<IterDomain*> cache_alloc = c->getMaybeAllocationDomain();
        NVF_ERROR(orig_alloc.size() == cache_alloc.size());
        for (size_t i : c10::irange(orig_alloc.size())) {
          ValGroup vg = graph_->toGroup(orig_alloc[i]);
          graph_->initializeVal(cache_alloc[i], vg);
        }
      }
    
      const std::vector<IterDomain*> orig_logical =
          TensorDomain::noReductions(orig->getLogicalDomain());
      const std::vector<IterDomain*> cache_logical = c->getLogicalDomain();
      // in split-K we do rFactor which gives us a full = sum(partial)
      // where partial has root domain that matches the logical domain of the
      // original tensor. The logical domain contains Iteration transforms of the
      // Reduction axis in the original mma output.
      NVF_ERROR(orig_logical.size() == cache_logical.size());
      for (size_t i : c10::irange(orig_logical.size())) {
        ValGroup vg = graph_->toGroup(orig_logical[i]);
        graph_->initializeVal(cache_logical[i], vg);
      }
    
      return c;
    }
    
    std::vector<std::vector<MatmulDimRole>> HopperMultipleMatmulScheduler::
        blockTileTensors(const std::vector<TensorView*>& tvs) {
      if (canonical_dim_ordering_.empty()) {
        canonical_dim_ordering_ =
            mma_utils::canonicalDimOrdering(tensor_roles_, id_roles_, *graph_);
      }
    
      std::vector<std::vector<MatmulDimRole>> all_merged_roles;
      for (TensorView* tv : tvs) {
        // Find dimensions in canonical_dim_ordering_ that exist in tv's loop
        // domain. Reorder those according to the canonical dim ordering then
        std::unordered_map<ValGroup, IterDomain*> tv_dims;
        std::unordered_set<MatmulDimRole> axis_roles;
        for (IterDomain* id : tv->getLoopDomain()) {
          ValGroup vg = graph_->toGroup(id);
          tv_dims.emplace(vg, id);
          // track axis roles in this tensor to use in makeTile
          auto it = id_roles_.find(vg);
          NVF_ERROR(it != id_roles_.end());
          axis_roles.insert(it->second);
        }
        std::vector<IterDomain*> new_loop;
        new_loop.reserve(tv->nDims());
        for (const ValGroup& vg : canonical_dim_ordering_) {
          auto it = tv_dims.find(vg);
          if (it != tv_dims.end()) {
            new_loop.push_back(it->second);
          }
        }
        NVF_ERROR((int64_t)new_loop.size() == tv->nDims());
        tv->setLoopDomain(new_loop);
    
        // There could be multiple dimensions with the same role at this point, so
        // now we collect them. After this, tv will be at most 4 dimensions e.g.
        // BMNK based on canonical_dim_ordering_, with any of these dimensions
        // possibly missing.
        mma_utils::mergeConsecutiveAxesWithSameRole(tv, id_roles_, graph_);
    
        // Find order the axes that are present in the merged tensor
        std::vector<MatmulDimRole> merged_roles;
        merged_roles.reserve(tv->nDims());
        for (const ValGroup& vg : canonical_dim_ordering_) {
          MatmulDimRole role = id_roles_[vg];
          if (axis_roles.count(role) != 0) {
            if (merged_roles.empty() || merged_roles.back() != role) {
              merged_roles.push_back(role);
            }
          }
        }
        NVF_ERROR(merged_roles.size() == axis_roles.size());
    
        // TODO: (to be pursued after the multi-matmul refactor is fully merged)
        // this currently creates a separate AbstractMatmulTensor for each
        // TensorView. Instead, we should create a single AbstractMatmulTensor
        // then apply it (with "forwarding") to each TV instead. We already cache
        // a vector<ValGroup> as canonical_dim_ordering_ so AbstractTensor
        // scheduling is the next step in this modernization.
        mma_utils::makeTile(tv, params_->tile_sizes.cta_tile, merged_roles);
    
        swizzleBlockTiles(tv, merged_roles);
    
        all_merged_roles.push_back(merged_roles);
    
        if (params_->splitk_factor > 1) {
          // Outer K dimension in tv is in same position found in merged_roles
          for (size_t i : c10::irange(merged_roles.size())) {
            if (merged_roles[i] == MatmulDimRole::K) {
              tv->split((int64_t)i, params_->splitk_factor, /*inner*/ false);
            }
          }
        }
      }
      return all_merged_roles;
    }
    
    void HopperMultipleMatmulScheduler::inspectPrologues() const {
      for (TensorView* mma_result : mma_results_) {
        for (Val* v : mma_result->definition()->inputs()) {
          TensorView* op_input = v->as<TensorView>();
    
          // We currently require all operands to lie in smem, meaning we cannot yet
          // handle any prologue computation. This includes `BroadcastOp` which
          // might be introduced when translating a MatmulOp or LinearOp to MmaOp.
          Expr* def = op_input->definition();
          NVF_ERROR(def != nullptr && def->isA<LoadStoreOp>());
          NVF_ERROR(def->input(0)->isFusionInput());
        }
      }
    }
    
    void HopperMultipleMatmulScheduler::scheduleOperands() {
      NVF_CHECK(
          params_->async_gmem_load_operands,
          "Hopper matmul scheduler currently requires TMA to be enabled");
      auto scheduleBranch = [&](const std::vector<TensorView*>& gmem_operands,
                                const std::vector<TensorView*>& smem_operands,
                                MmaOperand operand_type) {
        blockTileTensors(smem_operands);
        for (TensorView* tv : smem_operands) {
          if (params_->promote_prologue_smem_reuse) {
            tv->promoteReuse();
          }
          mma_utils::orderTiledConcreteIdAsMaybeAllocationDomain(tv);
          MmaInputSmemSwizzle swizzle_type = mma_utils::tmaSwizzleSharedMemory(tv);
          tv->applyMmaSwizzleForTMALoad(swizzle_type);
        }
      };
      scheduleBranch(as_, acw_smems_, MmaOperand::A);
      scheduleBranch(bs_, bcw_smems_, MmaOperand::B);
    }
    
    void HopperMultipleMatmulScheduler::parallelizeBlocks(
        const std::vector<TensorView*>& tvs) const {
      for (TensorView* tv : tvs) {
        switch (params_->cta_order) {
          // TODO: Should we instead check the roles of these dimensions to take the
          // outermost two M or N axes?
          case MatmulParams::TileRasterizationOrder::RowMajor:
            tv->axis(num_device_and_batch_dims_)->parallelize(ParallelType::BIDx);
            tv->axis(num_device_and_batch_dims_ + 1)
                ->parallelize(ParallelType::BIDy);
            break;
          case MatmulParams::TileRasterizationOrder::ColumnMajor:
            tv->axis(num_device_and_batch_dims_)->parallelize(ParallelType::BIDy);
            tv->axis(num_device_and_batch_dims_ + 1)
                ->parallelize(ParallelType::BIDx);
            break;
          default:
            NVF_THROW("Invalid TileRasterizationOrder passed to Matmul scheduler");
        }
      }
    }
    
    void HopperMultipleMatmulScheduler::scheduleMmaResults() {
      GemmTile instruction_tile = getMmaOpShape(params_->mma_macro);
      NVF_CHECK(
          params_->tile_sizes.cta_tile.k == params_->tile_sizes.warp_tile.k,
          "CTA tile must match warp tile K dimension for Hopper matmul but found ",
          toString(params_->tile_sizes));
      // If cta_tile is not divisible by instruction tile the mma instruction will
      // be predicated.
      NVF_CHECK(
          params_->tile_sizes.cta_tile.m % instruction_tile.m == 0 &&
              params_->tile_sizes.cta_tile.n % instruction_tile.n == 0 &&
              params_->tile_sizes.cta_tile.k % instruction_tile.k == 0,
          "CTA tile must be divisible by macro size but found cta_tile: ",
          toString(params_->tile_sizes.cta_tile),
          " and macro: ",
          toString(params_->mma_macro));
    
      // Schedule mma results and propagate forward
      auto all_merged_roles = blockTileTensors(mma_results_);
      parallelizeBlocks(mma_results_);
      for (size_t i : c10::irange(mma_results_.size())) {
        TensorView*& mma_result = mma_results_[i];
        const std::vector<MatmulDimRole>& merged_roles = all_merged_roles[i];
    
        // Test that mma_result logical is MNK
        // TODO: This currently checks leaf domain only which does not necessarily
        // match logical
        // TODO: Lift this constraint. Use commitLeafToLogical if necessary. We
        // might just want to match using id_roles_
        NVF_ERROR(merged_roles.size() >= 3);
        const auto checkSingleDimRole =
            [&merged_roles](int64_t pos, MatmulDimRole expected_role) {
              if (pos < 0) {
                pos += (int64_t)merged_roles.size();
              }
              NVF_ERROR(pos >= 0);
              NVF_ERROR(pos < (int64_t)merged_roles.size());
              const auto& actual_role = merged_roles[(size_t)pos];
              NVF_ERROR(actual_role == expected_role);
            };
        checkSingleDimRole(-3, MatmulDimRole::M);
        checkSingleDimRole(-2, MatmulDimRole::N);
        checkSingleDimRole(-1, MatmulDimRole::K);
    
        // do split-K rFactor to define splitk_sum and smem_epilogue
        if (params_->splitk_factor != 1) {
          // Note that the split-K split is already done in blockTileTensors
          TensorView* splitk_sum = mma_result->rFactor({-4, -1});
          std::swap(splitk_sum, mma_result);
          splitk_sums_.push_back(splitk_sum);
        }
    
        // Original: [..., Mo, No, Mi, Ni, Ki]
        mma_result->split(-3, getM(params_->mma_macro));
        mma_result->split(-2, getN(params_->mma_macro));
    
        // Split the k dimension of the warp tile only if it is larger than the
        // k dimension of the mma macro. Inlining can land at an incorrect
        // position for circular buffering if a reduction iterDomain has extent 1.
        if (params_->tile_sizes.warp_tile.k > getK(params_->mma_macro)) {
          mma_result->split(-1, getK(params_->mma_macro));
          // After Split: [..., Mo, No, Mio, Mii, Nio, Nii, Kio, Kii]
          mma_result->reorder({{-5, -4}});
          // After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii, Kio, Kii]
          mma_result->reorder({{-2, -4}});
          // After Reorder: [..., Mo, No, Mio, Nio, Kio, Mii, Nii, Kii]
          mma_result->merge(-6);
          // After Merge: [..., Mo, No, Mio * Nio, Kio, Mii, Nii, Kii]
          mma_result->axis(-5)->parallelize(ParallelType::TIDy);
          // After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Kio, Mii, Nii, Kii]
        } else {
          // After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
          mma_result->reorder({{-4, -3}});
          // After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
          mma_result->merge(-5);
          // After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
          mma_result->axis(-4)->parallelize(ParallelType::TIDy);
          // After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
        }
    
        auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
            mma_result->getLoopDomain());
        mma_result->setAllocationDomain(s.as<IterDomain*>(), true);
        mma_result->axis(-1)->parallelize(ParallelType::Mma);
        mma_result->axis(-2)->parallelize(ParallelType::Mma);
        mma_result->axis(-3)->parallelize(ParallelType::Mma);
      }
    }
    
    void HopperMultipleMatmulScheduler::scheduleEpilogue() {
      std::vector<TensorView*> cached_tvs;
    
      // Propagate to (not including) the splitk output if there is a splitk
      // else this is just mma_results_
      std::vector<TensorView*> propagate_to =
          splitk_sums_.empty() ? mma_results_ : splitk_sums_;
      if (tensor_roles_.count(MatmulTensorRole::EPILOGUE_INPUT)) {
        auto& c_tvs = tensor_roles_.at(MatmulTensorRole::EPILOGUE_INPUT);
        // Load/cache the epilogue inputs if there are any.
        for (auto* c : c_tvs) {
          cached_tvs.push_back(c->cacheAfter());
        }
        propagate_to.insert(propagate_to.end(), c_tvs.begin(), c_tvs.end());
      }
    
      if (!params_->use_smem_epilogue) {
        for (Val* dv : fusion_->outputs()) {
          auto* d = dv->as<TensorView>();
          NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
    
          // Schedule the output TV and propagate it back to the outputs of the Mma
          // op.
          blockTileTensors({d});
          parallelizeBlocks({d});
          transformLikeMmaOutput(d);
    
          auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
              d->getLoopDomain());
          d->setLoopDomain(s.as<IterDomain*>());
    
          // TODO: We need to check bank conflicts in this path.
          scheduler_utils::BoundedDirectionalTransformPropagator::backward(
              d,
              -1,
              propagate_to,
              scheduler_utils::BoundedDirectionalTransformPropagator::Options()
                  .propagateParallelType());
    
          // We don't respect vectorization_factor as yet. We vectorize the
          // inner-dim with extent 2.
          // TODO: support vectorization_factor.
          d->axis(-1)->parallelize(ParallelType::Vectorize);
          if (!cached_tvs.empty()) {
            scheduler_utils::parallelizeAllLike(d, -1, cached_tvs);
          }
        }
      } else {
        constexpr int64_t stmatrix_tile_m = 16;
        constexpr int64_t stmatrix_tile_n = 16;
    
        // TODO: Support tma tile sizes that are a multiple of mma_macro.
        // The wgmma operation creates an output matrix of mma_macro size. The TMA
        // tile is a multiple of the macro size because stmatrix stores results from
        // wgmma to shared memory. For maximum inlining and to reduce shared memory
        // usage, the tma tile is mma_macro size.
        const int64_t tma_m = getM(params_->mma_macro);
        const int64_t tma_n = getN(params_->mma_macro);
    
        fusion_->manage("st_matrix_m_tile", stmatrix_tile_m);
        fusion_->manage("st_matrix_n_tile", stmatrix_tile_n);
        fusion_->manage("st_matrix_m", tma_m);
        fusion_->manage("st_matrix_n", tma_n);
    
        // Manually schedule register cache and output TensorView
        for (Val* dv : fusion_->outputs()) {
          auto* d = dv->as<TensorView>();
          NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
          auto* dc = d->definition()->input(0)->as<TensorView>();
    
          // NOTE: cacheBefore does not work with blockTileTensors
          TensorView* d_smem = cacheAfter(dc, LoadStoreOpType::Set);
    
          std::vector<TensorView*> tvs_to_schedule{d, d_smem};
    
          bool dc_in_mma_results =
              std::find(mma_results_.begin(), mma_results_.end(), dc) !=
              mma_results_.end();
    
          if (!dc_in_mma_results) {
            // Skip scheduling dc if it is an mma_result. This can happen if we are
            // not casting back to half-precision in the output
            tvs_to_schedule.push_back(dc);
          }
    
          // Set MemoryType
          dc->setMemoryType(MemoryType::Local);
          d_smem->setMemoryType(MemoryType::Shared);
    
          auto store_with_stmatrix = dataTypeSize(dc->dtype()) == 2;
    
          if (store_with_stmatrix) {
            // Set LoadStoreOp
            d_smem->definition()->as<LoadStoreOp>()->setOpType(
                LoadStoreOpType::StMatrix);
          }
          d->definition()->as<LoadStoreOp>()->setOpType(
              LoadStoreOpType::CpAsyncBulkTensorTile);
    
          // Apply the common transforms to dc, d_smem, d
          // After these transforms we schedule the inner two non-reduction loops
          // (instruction tile) of dc and propagate it back to the outputs of the mma.
          blockTileTensors(tvs_to_schedule);
          parallelizeBlocks(tvs_to_schedule);
          for (auto tv : tvs_to_schedule) {
            transformLikeMmaOutput(tv);
          }
    
          // Should not propagate if the dc is a mma output as the mma output has
          // already been scheduled.
          if (!dc_in_mma_results) {
            auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
                dc->getLoopDomain());
            dc->setLoopDomain(s.as<IterDomain*>());
            dc->setAllocationDomain(s.as<IterDomain*>(), true);
    
            scheduler_utils::BoundedDirectionalTransformPropagator::backward(
                dc,
                -1,
                propagate_to,
                scheduler_utils::BoundedDirectionalTransformPropagator::Options()
                    .propagateParallelType());
          }
    
          MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(d_smem);
    
          // [M, N] -> [128(TIDx), N/8 ,  m(2) , n(2)]
          auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
              d_smem->getLoopDomain());
          if (swizzle != MmaInputSmemSwizzle::None) {
            // Create tma store allocation domain with swizzle
            mma_utils::scheduleTMAStoreForMmaOutput(d_smem, swizzle);
          }
          d_smem->setLoopDomain(s.as<IterDomain*>());
    
          if (store_with_stmatrix) {
            // Schedule shared memory cache; Output from StMatrix
            mma_utils::scheduleStMatrixForMmaOutput(
                d_smem, swizzle, stmatrix_tile_m, stmatrix_tile_n);
          }
    
          d_smem->axis(-1)->parallelize(ParallelType::Vectorize);
    
          // Schedule global memory output; Output from TMA Store
          mma_utils::scheduleTMAStoreForMmaOutput(d, swizzle);
        }
      }
    }
    
    void HopperMultipleMatmulScheduler::scheduleSplitKSum() {
      if (params_->splitk_factor == 1) {
        return;
      }
      for (TensorView* splitk_sum : splitk_sums_) {
        // Always use serial grid reduction for split-K sum
        splitk_sum->definition()->as<ReductionOp>()->requestSerialGridReduction();
        transformLikeMmaOutput(splitk_sum);
        auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
            splitk_sum->getLoopDomain());
        splitk_sum->setLoopDomain(s.as<IterDomain*>());
Incorrect Results

The tests have been modified to use larger k dimensions, but the correctness checks are commented out. Verify that the changes do not introduce incorrect results and re-enable the checks.

      auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
      auto a_ref = at::randn({K, M, 1}, options);
      auto b_ref = at::randn({K, 1, N}, options);
      auto out_ref = at::matmul(a_ref.squeeze().t(), b_ref.squeeze()).to(at::kHalf);
    
      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(128, 256, 32);
      gemm_tile.warp_tile = GemmTile(64, 256, 32);
    
      MatmulParams mparams;
      mparams.supported_vec_size = {8, 8, 8};
      mparams.mma_macro = MmaMacro::Hopper_64_256_16;
      mparams.tile_sizes = gemm_tile;
      mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor;
      mparams.async_gmem_load_operands = true;
      mparams.circular_buffer_options.circular_buffer_smem_write = true;
      mparams.circular_buffer_options.circular_buffer_smem_read = false;
      mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
      mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
      mparams.splitk_factor = 1;
      mparams.use_smem_epilogue = true;
      mparams.cluster_dims = {2, 1, 1};
      mparams.promote_prologue_smem_reuse = true;
    
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      std::vector<c10::IValue> inputs = {a_ref, b_ref};
    
      KernelExecutor ke;
      ke.compile(&fusion, inputs);
      EXPECT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      auto cg_outputs = ke.run(inputs);
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
    
      // Relax tolerance for larger sum due to large K
      EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
    }
    
    TEST_F(HopperMatmulTest, HSH_TN_UseScheduler) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      constexpr int64_t M = 2048, N = 2048, K = 8192;
      const auto dtype = DataType::Half;
    
      auto tv0 = makeContigConcreteTensor({-1, 1, -1}, dtype); // M, K
      auto tv1 = makeContigConcreteTensor({1, -1, -1}, dtype); // N, K
      fusion.addInput(tv0);
      fusion.addInput(tv1);
    
      auto tv2 = fusedMultiplySum(tv0, tv1, {-1});
    
      auto tv3 = castOp(DataType::Half, tv2);
      fusion.addOutput(tv3);
    
      auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
      auto a_ref = at::randn({M, 1, K}, options);
      auto b_ref = at::randn({1, N, K}, options);
      auto out_ref = at::matmul(a_ref.squeeze(), b_ref.squeeze().t()).to(at::kHalf);
    
      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(128, 256, 32);
      gemm_tile.warp_tile = GemmTile(64, 256, 32);
    
      MatmulParams mparams;
      mparams.supported_vec_size = {8, 8, 8};
      mparams.mma_macro = MmaMacro::Hopper_64_256_16;
      mparams.tile_sizes = gemm_tile;
      mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor;
      mparams.async_gmem_load_operands = true;
      mparams.circular_buffer_options.circular_buffer_smem_write = true;
      mparams.circular_buffer_options.circular_buffer_smem_read = false;
      mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
      mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
      mparams.splitk_factor = 1;
      mparams.use_smem_epilogue = true;
      mparams.cluster_dims = {2, 1, 1};
      mparams.promote_prologue_smem_reuse = true;
    
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      std::vector<c10::IValue> inputs = {a_ref, b_ref};
    
      KernelExecutor ke;
      ke.compile(&fusion, inputs);
      EXPECT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      auto cg_outputs = ke.run(inputs);
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
    
      // Relax tolerance for larger sum due to large K
      EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
    }
    
    TEST_F(HopperMatmulTest, HSH_NN_UseScheduler) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      constexpr int64_t M = 2048, N = 2048, K = 8192;
      const auto dtype = DataType::Half;
    
      auto tv0 = makeContigConcreteTensor({1, -1, -1}, dtype); // K, M
      auto tv1 = makeContigConcreteTensor({-1, -1, 1}, dtype); // N, K
      fusion.addInput(tv0);
      fusion.addInput(tv1);
    
      auto tv2 = fusedMultiplySum(tv0, tv1, {1});
    
      // Reorder the accumulator as [M, N, K]
      // [M, K, N] -> [M, N, K]
      tv2->reorder({{-1, -3}});
      tv2->commitLeafToLogical();
    
      auto tv3 = castOp(DataType::Half, tv2);
      fusion.addOutput(tv3);
    
      auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
      auto a_ref = at::randn({1, K, M}, options);
      auto b_ref = at::randn({N, K, 1}, options);
      auto out_ref =
          at::matmul(a_ref.squeeze().t(), b_ref.squeeze().t()).to(at::kHalf);
    
      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(128, 256, 32);
      gemm_tile.warp_tile = GemmTile(64, 256, 32);
    
      MatmulParams mparams;
      mparams.supported_vec_size = {8, 8, 8};
      mparams.mma_macro = MmaMacro::Hopper_64_256_16;
      mparams.tile_sizes = gemm_tile;
      mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor;
      mparams.async_gmem_load_operands = true;
      mparams.circular_buffer_options.circular_buffer_smem_write = true;
      mparams.circular_buffer_options.circular_buffer_smem_read = false;
      mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
      mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
      mparams.splitk_factor = 1;
      mparams.use_smem_epilogue = true;
      mparams.cluster_dims = {2, 1, 1};
      mparams.promote_prologue_smem_reuse = true;
    
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      std::vector<c10::IValue> inputs = {a_ref, b_ref};
    
      KernelExecutor ke;
      ke.compile(&fusion, inputs);
      EXPECT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      auto cg_outputs = ke.run(inputs);
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
    
      // Relax tolerance for larger sum due to large K
      EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
    }
    
    TEST_F(HopperMatmulTest, HSH_TT_UseScheduler) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      constexpr int64_t M = 2048, N = 2048, K = 8192;
      const auto dtype = DataType::Half;
    
      auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // M, K
      auto tv1 = makeContigConcreteTensor({1, -1, -1}, dtype); // K, N
      fusion.addInput(tv0);
      fusion.addInput(tv1);
    
      auto tv2 = fusedMultiplySum(tv0, tv1, {1});
    
      // Reorder the accumulator as [M, N, K]
      // [M, K, N] -> [M, N, K]
      tv2->reorder({{-2, -1}});
      tv2->commitLeafToLogical();
    
      auto tv3 = castOp(DataType::Half, tv2);
      fusion.addOutput(tv3);
    
      auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
      auto a_ref = at::randn({M, K, 1}, options);
      auto b_ref = at::randn({1, K, N}, options);
      auto out_ref = at::matmul(a_ref.squeeze(), b_ref.squeeze()).to(at::kHalf);
    
      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(128, 256, 32);
      gemm_tile.warp_tile = GemmTile(64, 256, 32);
    
      MatmulParams mparams;
      mparams.supported_vec_size = {8, 8, 8};
      mparams.mma_macro = MmaMacro::Hopper_64_256_16;
      mparams.tile_sizes = gemm_tile;
      mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor;
      mparams.async_gmem_load_operands = true;
      mparams.circular_buffer_options.circular_buffer_smem_write = true;
      mparams.circular_buffer_options.circular_buffer_smem_read = false;
      mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
      mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
      mparams.splitk_factor = 1;
      mparams.use_smem_epilogue = true;
      mparams.cluster_dims = {2, 1, 1};
      mparams.promote_prologue_smem_reuse = true;
    
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      std::vector<c10::IValue> inputs = {a_ref, b_ref};
    
      KernelExecutor ke;
      ke.compile(&fusion, inputs);
      EXPECT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      auto cg_outputs = ke.run(inputs);
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
    
      // Relax tolerance for larger sum due to large K
      EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
    }
    
    struct MLPBenchmarkTestParams {
      bool warp_specialization;
      bool persistent_kernel;
    };
    
    class MLPBenchmarkTest
        : public HopperBase,
          public ::testing::WithParamInterface<MLPBenchmarkTestParams> {
     protected:
      MLPBenchmarkTestParams test_params;
      void SetUp() override {
        HopperBase::SetUp();
    
        test_params = GetParam();
    
        if (test_params.persistent_kernel) {
          GTEST_SKIP() << "persistent kernel tests are currently disabled";
        }
      }
    };
    
    TEST_P(MLPBenchmarkTest, FwdGEMM) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      constexpr int64_t M = 4096, N = 14336, K = 5120;
      const auto dtype = DataType::BFloat16;
    
      auto tv0 = makeContigConcreteTensor({-1, -1}, dtype); // M, K
      auto tv1 = makeContigConcreteTensor({-1, -1}, dtype); // N, K
      fusion.addInput(tv0);
      fusion.addInput(tv1);
    
      auto tv2 = linear(tv0, tv1);
    
      fusion.addOutput(tv2);
    
      auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA);
      auto a_ref = at::randn({M, K}, options);
      auto b_ref = at::randn({N, K}, options);
      auto out_ref = at::linear(a_ref, b_ref);
    
      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(128, 256, 64);
      gemm_tile.warp_tile = GemmTile(64, 256, 64);
    
      MatmulParams mparams;
      mparams.supported_vec_size = {8, 8, 8};
      mparams.mma_macro = MmaMacro::Hopper_64_256_16;
      mparams.tile_sizes = gemm_tile;
      mparams.cta_order = MatmulParams::TileRasterizationOrder::RowMajor;
      mparams.async_gmem_load_operands = true;
      mparams.circular_buffering_strategy = test_params.warp_specialization
          ? MatmulParams::CircularBufferingStrategy::WarpSpecialized
          : MatmulParams::CircularBufferingStrategy::Pipelined;
      mparams.tiling_strategy = test_params.persistent_kernel
          ? MatmulParams::TilingStrategy::DistributeTilesAcrossSMs
          : MatmulParams::TilingStrategy::OneTilePerCTA;
      mparams.circular_buffer_options.circular_buffer_smem_write = true;
      mparams.circular_buffer_options.circular_buffer_smem_read = false;
      mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
      mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
      mparams.splitk_factor = 1;
      mparams.use_smem_epilogue = true;
      mparams.cluster_dims = {1, 2, 1};
      mparams.promote_prologue_smem_reuse = true;
    
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      std::vector<c10::IValue> inputs = {a_ref, b_ref};
    
      KernelExecutor ke;
      ke.compile(&fusion, inputs);
      EXPECT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      auto cg_outputs = ke.run(inputs);
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
    
      // Relax tolerance for larger sum due to large K
      // TODO Incorrect results because incorrect placement of wgmma syncs
      // EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
    }
    
    TEST_P(MLPBenchmarkTest, FwdEpilogueFusion) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      constexpr int64_t M = 4096, N = 14336, K = 5120;
      const auto dtype = DataType::BFloat16;
    
      auto tv0 = makeContigConcreteTensor({-1, -1}, dtype); // M, K
      auto tv1 = makeContigConcreteTensor({-1, -1}, dtype); // N, K
      auto tv2 = makeContigConcreteTensor({-1, -1}, dtype); // M, N
      fusion.addInput(tv0);
      fusion.addInput(tv1);
      fusion.addInput(tv2);
    
      auto tv3 = linear(tv0, tv1);
      fusion.addOutput(tv3);
    
      auto tv4 = castOp(DataType::Float, tv3);
      auto tv5 = neg(tv4);
      auto tv6 = exp(tv5);
      auto tv7 = add(fusion.oneVal(DataType::Float), tv6);
      auto tv8 = reciprocal(tv7);
      auto tv9 = mul(tv4, tv8);
      auto tv10 = mul(tv9, tv2);
      auto tv11 = castOp(DataType::BFloat16, tv10);
      fusion.addOutput(tv11);
    
      auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA);
      auto a_ref = at::randn({M, K}, options);
      auto b_ref = at::randn({N, K}, options);
      auto c_ref = at::randn({M, N}, options);
    
      auto tv3_ref = at::linear(a_ref, b_ref);
      auto tv4_ref = tv3_ref.to(at::kFloat);
      auto tv11_ref =
          (tv4_ref * (1. / (1.0 + at::exp(-tv4_ref))) * c_ref).to(at::kBFloat16);
    
      MatmulParams mparams;
      mparams.supported_vec_size = {8, 8, 8};
      mparams.mma_macro = MmaMacro::Hopper_64_256_16;
      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(128, 256, 64);
      gemm_tile.warp_tile = GemmTile(64, 256, 64);
      mparams.tile_sizes = gemm_tile;
      mparams.cta_order = MatmulParams::TileRasterizationOrder::RowMajor;
      mparams.circular_buffering_strategy = test_params.warp_specialization
          ? MatmulParams::CircularBufferingStrategy::WarpSpecialized
          : MatmulParams::CircularBufferingStrategy::Pipelined;
      mparams.tiling_strategy = test_params.persistent_kernel
          ? MatmulParams::TilingStrategy::DistributeTilesAcrossSMs
          : MatmulParams::TilingStrategy::OneTilePerCTA;
      mparams.async_gmem_load_operands = true;
      mparams.circular_buffer_options.circular_buffer_smem_write = true;
      mparams.circular_buffer_options.circular_buffer_smem_read = true;
      mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
      mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
      mparams.splitk_factor = 1;
      mparams.use_smem_epilogue = true;
      mparams.cluster_dims = {1, 2, 1};
      mparams.promote_prologue_smem_reuse = true;
    
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      std::vector<c10::IValue> inputs = {a_ref, b_ref, c_ref};
    
      KernelExecutor ke;
      ke.compile(&fusion, inputs);
      EXPECT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      auto cg_outputs = ke.run(inputs);
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
    
      // Relax tolerance for larger sum due to large K
      // TODO Incorrect results because incorrect placement of wgmma syncs
      // EXPECT_TRUE(cg_outputs[0].allclose(tv3_ref, 1e-6 * K, 1e-6 * K));
      // EXPECT_TRUE(cg_outputs[1].allclose(tv11_ref, 1e-2, 1e-2));
    }
    
    TEST_P(MLPBenchmarkTest, FwdHorizontalFusion) {
      EnableOptionsGuard eog;
      EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMultipleMatmuls);
    
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      constexpr int64_t M = 4096, N = 14336, K = 5120;
      const auto dtype = DataType::BFloat16;
    
      auto tv0 = makeContigConcreteTensor({-1, -1}, dtype); // M, K
      auto tv1 = makeContigConcreteTensor({-1, -1}, dtype); // N, K
      auto tv2 = makeContigConcreteTensor({-1, -1}, dtype); // N, K
      fusion.addInput(tv0);
      fusion.addInput(tv1);
      fusion.addInput(tv2);
    
      auto tv3 = linear(tv0, tv1);
      fusion.addOutput(tv3);
    
      auto tv4 = castOp(DataType::Float, tv3);
      auto tv5 = neg(tv4);
      auto tv6 = exp(tv5);
      auto tv7 = add(fusion.oneVal(DataType::Float), tv6);
      auto tv8 = reciprocal(tv7);
      auto tv9 = mul(tv4, tv8);
    
      auto tv10 = linear(tv0, tv2);
      fusion.addOutput(tv10);
    
      auto tv11 = mul(tv9, tv10);
      auto tv12 = castOp(DataType::BFloat16, tv11);
      fusion.addOutput(tv12);
    
      auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA);
      auto a_ref = at::randn({M, K}, options);
      auto b_ref = at::randn({N, K}, options);
      auto c_ref = at::randn({N, K}, options);
    
      auto tv3_ref = at::linear(a_ref, b_ref);
      auto tv4_ref = tv3_ref.to(at::kFloat);
      auto tv10_ref = at::linear(a_ref, c_ref);
      auto tv12_ref =
          (tv4_ref * (1. / (1.0 + at::exp(-tv4_ref))) * tv10_ref.to(at::kFloat))
              .to(at::kBFloat16);
    
      MatmulParams mparams;
      mparams.supported_vec_size = {8, 8, 8};
      mparams.mma_macro = MmaMacro::Hopper_64_128_16;
      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(128, 128, 64);
      gemm_tile.warp_tile = GemmTile(64, 128, 64);
      mparams.tile_sizes = gemm_tile;
      mparams.cta_order = MatmulParams::TileRasterizationOrder::RowMajor;
      mparams.circular_buffering_strategy = test_params.warp_specialization
          ? MatmulParams::CircularBufferingStrategy::WarpSpecialized
          : MatmulParams::CircularBufferingStrategy::Pipelined;
      mparams.tiling_strategy = test_params.persistent_kernel
          ? MatmulParams::TilingStrategy::DistributeTilesAcrossSMs
          : MatmulParams::TilingStrategy::OneTilePerCTA;
      mparams.async_gmem_load_operands = true;
      mparams.circular_buffer_options.circular_buffer_smem_write = true;
      mparams.circular_buffer_options.circular_buffer_smem_read = true;
      mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
      mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
      mparams.splitk_factor = 1;
      mparams.use_smem_epilogue = true;
      mparams.cluster_dims = {1, 2, 1};
      mparams.promote_prologue_smem_reuse = true;
    
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      std::vector<c10::IValue> inputs = {a_ref, b_ref, c_ref};
    
      KernelExecutor ke;
      ke.compile(&fusion, inputs);
      EXPECT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      auto cg_outputs = ke.run(inputs);
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
    
      // TODO Incorrect results because of incorrect placement of wgmma syncs
      // TODO Incorrect results because of a WAR hazard on shared memory aliased
      // between tv3 and tv12
      // Relax tolerance for larger sum due to large K
      // EXPECT_TRUE(cg_outputs[0].allclose(tv3_ref, 1e-6 * K, 1e-6 * K));
      // EXPECT_TRUE(cg_outputs[1].allclose(tv10_ref, 1e-6 * K, 1e-6 * K));
      // EXPECT_TRUE(cg_outputs[2].allclose(tv12_ref, 1e-2, 1e-1));
    }
    TODO Items

    There are several TODO items in the test fixture below. The reviewer should verify that the CTA and warp tile sizes are multiples of the Hopper MMA macro dimensions (the default here is a 2-CTA configuration) and that the supported vector size of {8, 8, 8} is appropriate. A short sketch of what these tile sizes evaluate to for the macro used in this suite follows.
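
    For concreteness, here is a minimal sketch (not part of the test file) of what these tile computations evaluate to for MmaMacro::Hopper_64_128_16, the macro used by this suite, assuming getM/getN/getK return the macro's 64/128/16 extents. The Tile struct is a hypothetical stand-in for nvFuser's GemmTile.

        #include <cstdint>
        #include <iostream>

        // Hypothetical stand-in for nvFuser's GemmTile (M, N, K extents).
        struct Tile {
          int64_t m, n, k;
        };

        int main() {
          // MmaMacro::Hopper_64_128_16 corresponds to a 64x128x16 wgmma instruction.
          constexpr int64_t macro_m = 64, macro_n = 128, macro_k = 16;

          // Default 2-CTA configuration: the CTA tile doubles M and K relative
          // to the macro.
          Tile cta_tile{2 * macro_m, macro_n, 2 * macro_k}; // 128 x 128 x 32
          // The warp tile keeps the macro's M and N but spans two k-steps.
          Tile warp_tile{macro_m, macro_n, 2 * macro_k};    // 64 x 128 x 32

          // With warp_tile.k == 2 * macro_k, each warp tile issues two wgmma
          // operations along k, whose partial results are combined by a serial
          // reduction.
          std::cout << "wgmma ops per warp tile along k: "
                    << (warp_tile.k / macro_k) << std::endl;
          return 0;
        }

    The relevant portion of the fixture's parameter setup follows.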

        // Create custom Matmul Params
        MatMulTileOptions gemm_tile;
        // TODO cta tile is a multiple of mma macro for hopper.
        // Default cta_tile configuration is 2-CTA.
        gemm_tile.cta_tile =
            GemmTile(2 * getM(mma_macro), getN(mma_macro), 2 * getK(mma_macro));
    
        // TODO warp tile is (macroM, macroN, macroK) for hopper.
        gemm_tile.warp_tile =
            GemmTile(getM(mma_macro), getN(mma_macro), 2 * getK(mma_macro));
    
        mparams.supported_vec_size = {8, 8, 8};
    
        mparams.mma_macro = mma_macro;
    
        mparams.use_smem_epilogue = use_smem_epilogue;
    
        mparams.splitk_factor = splitk_factor;
        mparams.tile_sizes = gemm_tile;
        mparams.async_gmem_load_operands = true;
        mparams.circular_buffer_options.circular_buffer_smem_write = true;
        mparams.circular_buffer_options.circular_buffer_smem_read = true;
        mparams.circular_buffer_options.smem_circular_buffer_stage = 2;
      }
    
      void TearDown() {
        if (testing::Test::IsSkipped() || testing::Test::HasFailure()) {
          return;
        }
    
        NVF_CHECK(
            1 == ir_utils::getOpsOfType<MmaOp>(fusion).size(),
            "matmul fusion must have exactly one MmaOp");
    
        // Schedule matmul fusion using custom parameters
        SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
            ->schedule(fusion, &mparams);
    
        KernelExecutor ke;
        ke.compile(fusion, inputs, LaunchParams(), matmul_cparams);
        auto nvf_out = ke.run(inputs);
        EXPECT_TRUE(at::allclose(nvf_out.at(0), tref, 1e-2, 1e-2));
      }
    
     protected:
      bool use_smem_epilogue;
      bool a_k_inner, b_k_inner;
      int64_t M, N, K;
      MmaMacro mma_macro;
      int64_t splitk_factor;
      std::unique_ptr<Fusion> fusion_up;
      Fusion* fusion;
      std::unique_ptr<FusionGuard> fusion_guard;
      DataType dtype = DataType::Half;
    
      MmaLayout layout;
    
      MatmulParams mparams;
    
      std::vector<c10::IValue> inputs;
    
      // Tests should place the reference tensor here
      at::Tensor tref;
    };
    
    TEST_P(HopperMatmulSchedulerTest, FusedMultiplySum) {
      const auto& [A, B] =
          matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));
      inputs = {A, B};
    
      TensorView* tv0 = nullptr;
      TensorView* tv1 = nullptr;
      std::unordered_map<int64_t, int64_t> old2new;
      int64_t k_axis = 0;
    
      switch (layout) {
        case MmaLayout::TT:
          // Inner dims KN, order is MKN
          tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype);
          tv1 = makeContigConcreteTensor({1, -1, -1}, dtype);
          old2new = {{-2, -1}, {-1, -2}};
          k_axis = -2;
          break;
        case MmaLayout::TN:
          // Inner dims KK, order is MNK
          tv0 = makeContigConcreteTensor({-1, 1, -1}, dtype);
          tv1 = makeContigConcreteTensor({1, -1, -1}, dtype);
          old2new = {};
          k_axis = -1;
          break;
        case MmaLayout::NT:
          // Inner dims MN, order is KMN
          tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype);
          tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype);
          old2new = {{-3, -1}};
          k_axis = -3;
          break;
        case MmaLayout::NN:
          // Inner dims MK, order is NKM
          tv0 = makeContigConcreteTensor({1, -1, -1}, dtype);
          tv1 = makeContigConcreteTensor({-1, -1, 1}, dtype);
          old2new = {{-1, -3}};
          k_axis = -2;
          break;
      }
    
      fusion->addInput(tv0);
      fusion->addInput(tv1);
    
      auto tv2 = fusedMultiplySum(tv0, tv1, {k_axis});
    
      // Reorder the accumulator as [M, N, K]
      tv2->reorder(old2new);
      tv2->commitLeafToLogical();
    
      auto tv3 = castOp(dtype, tv2);
      fusion->addOutput(tv3);
    
      tref = atMatmul(A.squeeze(), B.squeeze(), layout);
    }
    
    // TODO: Remove this test once the architecture-agnostic test can be
    // run on Hopper.
    TEST_P(HopperMatmulSchedulerTest, FusedMultiplySumBiasNeg) {
      const auto& [A, B] =
          matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));
      const auto& C = matmulAtInput2D(
          layout, TensorMatmulPos::Bias, data_type_to_aten(dtype), M, N, K);
      inputs = {A, B, C};
    
      TensorView* tv0 = nullptr;
      TensorView* tv1 = nullptr;
      std::unordered_map<int64_t, int64_t> old2new;
      int64_t k_axis = 0;
    
      switch (layout) {
        case MmaLayout::TT:
          // Inner dims KN, order is MKN
          tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype);
          tv1 = makeContigConcreteTensor({1, -1, -1}, dtype);
          old2new = {{-2, -1}, {-1, -2}};
          k_axis = -2;
          break;
        case MmaLayout::TN:
          // Inner dims KK, order is MNK
          tv0 = makeContigConcreteTensor({-1, 1, -1}, dtype);
          tv1 = makeContigConcreteTensor({1, -1, -1}, dtype);
          old2new = {};
          k_axis = -1;
          break;
        case MmaLayout::NT:
          // Inner dims MN, order is KMN
          tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype);
          tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype);
          old2new = {{-3, -1}};
          k_axis = -3;
          break;
        case MmaLayout::NN:
          // Inner dims MK, order is NKM
          tv0 = makeContigConcreteTensor({1, -1, -1}, dtype);
          tv1 = makeContigConcreteTensor({-1, -1, 1}, dtype);
          old2new = {{-1, -3}};
          k_axis = -2;
          break;
      }
      TensorView* tv2 = makeContigConcreteTensor({-1}, dtype);
    
      fusion->addInput(tv0);
      fusion->addInput(tv1);
      fusion->addInput(tv2);
    
      auto tv3 = fusedMultiplySum(tv0, tv1, {k_axis});
    
      // Reorder the accumulator as [M, N, K]
      tv3->reorder(old2new);
      tv3->commitLeafToLogical();
    
      auto* tv4 = maybeCastOp(DataType::Float, tv2);
      auto* tv5 = biasEpilogue(tv3, tv4);
      auto* tv6 = neg(tv5);
      auto* tv7 = castOp(dtype, tv6);
      fusion->addOutput(tv7);
    
      tref = atBiasEpilogue(
                 atMatmul(A.squeeze(), B.squeeze(), layout),
                 C.to(data_type_to_aten(DataType::Float)))
                 .neg_()
                 .to(data_type_to_aten(DataType::Half));
    }
    
    INSTANTIATE_TEST_SUITE_P(
        General,
        HopperMatmulSchedulerTest,
        testing::Combine(
            testing::Bool(), // use_smem_epilogue
            testing::Bool(), // a_k_inner
            testing::Bool(), // b_k_inner
            testing::Values(512), // M
            testing::Values(256), // N
            testing::Values(128), // K
            testing::Values(MmaMacro::Hopper_64_128_16), // mma_macros
            testing::Values(1, 2) // SplitK Factor
            ),
        hopperTestName);

    @rdspring1 rdspring1 force-pushed the hopper_matmul_cta_k_fix branch from f0ffe15 to 0370585 on January 31, 2025 22:13
    @rdspring1
    Collaborator Author

    !test

    @rdspring1 rdspring1 changed the title from "Fix matmul incorrect results when k dim for CTA tile is a multiple of 16" to "Issue multiple wgmma operations when CTA k dim is a multiple of 16" on Feb 3, 2025
    @rdspring1
    Collaborator Author

    Future TODOs:

    1. Insert wgmma syncs correctly to get correct results (a sketch of the expected sync pattern follows this list).
    2. Fix inlining
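
    For context on item 1, the sketch below illustrates the PTX-level synchronization pattern that wgmma generally requires: a fence before issuing the instructions, a commit after each group, and a wait before the accumulators are read. This is illustrative CUDA inline assembly for sm_90a, not nvFuser's actual code generation; the wgmma.mma_async instruction itself is elided.

        __device__ inline void wgmma_fence() {
          // Make prior writes to accumulator/operand registers visible to wgmma.
          asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
        }

        __device__ inline void wgmma_commit_group() {
          // Close the group of wgmma.mma_async instructions issued since the
          // last commit.
          asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
        }

        template <int kPending>
        __device__ inline void wgmma_wait_group() {
          // Block until at most kPending committed wgmma groups remain in flight.
          asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(kPending)
                       : "memory");
        }

        // Typical placement per main-loop iteration:
        //   wgmma_fence();          // before issuing wgmma.mma_async instructions
        //   ... issue one wgmma.mma_async per macro tile along k ...
        //   wgmma_commit_group();   // close the group for this iteration
        //   wgmma_wait_group<0>();  // before the accumulators are read (epilogue)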

    @rdspring1 rdspring1 merged commit 9dc94c0 into main Feb 3, 2025
    51 checks passed
    @rdspring1 rdspring1 deleted the hopper_matmul_cta_k_fix branch February 3, 2025 16:35