
Split Hopper MMA by warp-tile before instruction tile #3642

Merged: 21 commits from hopper_warptile_split into main, Feb 6, 2025

Conversation

@jacobhinkle (Collaborator) commented Dec 24, 2024

Currently we ignore the warp tile parameter when scheduling Hopper matmuls (see #3636). This PR introduces a test with different CTA, warp, and instruction tiles and modifies the Hopper scheduler to split by warp tile in addition to instruction tile. Note that the instruction tile split results in two serial loop domains, so we wind up executing multiple mma instructions in each main-loop iteration. In the included example, warp_tile is 64, 128, 16 and the macro is Hopper_64_8_16. In this case, there are 128/8 = 16 instruction tiles per warp tile, so the generated main loop looks like this:

  #pragma unroll 3
  for(nvfuser_index_t i33 = 0; i33 < i4; ++i33) {
    nvfuser_index_t i34;
    i34 = 48 + (16 * i33);
    nvfuser_index_t i35;
    i35 = (3 + i33) % 4;
    unsigned i36;
    i36 = i7 + (8192 * i35);
    unsigned i37;
    i37 = i10 + (4096 * i35);
    nvfuser_index_t i38;
    i38 = i33 % 4;
    unsigned i39;
    i39 = i13 + (4096 * i38);
    uint64_t i40;
    i40 = 4611686293305294848ULL | ((262143ULL & (uint64_t)(i39)) >> 4ULL);
    unsigned i41;
    i41 = i15 + (8192 * i38);
    if (((Hopper::electSync(4294967295U) && b22) && b23)) {
      mbarrier::arriveExpectTX(toSmem((&T8[((3LL + i33) % 4)])), 8192U);
      #pragma unroll
      for(nvfuser_index_t i31 = 0; i31 < 4; ++i31) {
        Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr5, (Array<nvfuser_index_t, 2, 1>{(i6 + (64 * i31)), i34}), toSmem((&T8[((3LL + i33) % 4)])) }), (i36 + (2048 * i31)));
      }
      mbarrier::arriveExpectTX(toSmem((&T8[((3LL + i33) % 4)])), 4096U);
      #pragma unroll
      for(nvfuser_index_t i32 = 0; i32 < 2; ++i32) {
        Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr8, (Array<nvfuser_index_t, 2, 1>{(i9 + (64 * i32)), i34}), toSmem((&T8[((3LL + i33) % 4)])) }), (i37 + (2048 * i32)));
      }
    }
    mbarrier::waitParity(toSmem((&T8[(i33 % 4)])), (uint32_t)(((i33 / 4) % 2)));
    #pragma unroll
    for(nvfuser_index_t i25 = 0; i25 < 16; ++i25) {
      unsigned i42;
      i42 = (i41 + (2048 * (i25 / 8))) + (16 * (i25 % 8));
      asm volatile(
        "{\n"
        "  .reg .pred p0; \n"
        "  setp.ne.b32 p0, %6, 0;\n"
        "  wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 {%0, %1, %2, %3}, %4, %5, p0, %7, %8, %9, %10;\n"
        "}\n"
        :"+f"((*reinterpret_cast<Array<float, 4, 1>*>(&T2[(4 * i25)]))[0]),
         "+f"((*reinterpret_cast<Array<float, 4, 1>*>(&T2[(4 * i25)]))[1]),
         "+f"((*reinterpret_cast<Array<float, 4, 1>*>(&T2[(4 * i25)]))[2]),
         "+f"((*reinterpret_cast<Array<float, 4, 1>*>(&T2[(4 * i25)]))[3])
        :"l"(i40),
         "l"((4611686293305294848ULL | ((262143ULL & (uint64_t)(i42)) >> 4ULL))),
         "n"((uint32_t)(true)),
         "n"(1),
         "n"(1),
         "n"(1),
         "n"(1)
      );
    }
    __syncthreads();
    asm volatile("wgmma.commit_group.sync.aligned;\n");
    asm volatile("wgmma.wait_group.sync.aligned %0;\n"::"n"(0LL):"memory");
  }
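
For intuition, the 16 wgmma calls per main-loop iteration follow directly from the ratio of warp tile to instruction tile; a quick sketch of the arithmetic (illustrative only, not nvFuser code):

  #include <cstdio>

  // warp_tile = (64, 128, 16); macro Hopper_64_8_16 -> instruction tile (64, 8, 16).
  int main() {
    const int warp_m = 64, warp_n = 128, warp_k = 16;
    const int inst_m = 64, inst_n = 8, inst_k = 16;
    // Serial instruction tiles per warp tile, one wgmma each:
    const int n_inst = (warp_m / inst_m) * (warp_n / inst_n) * (warp_k / inst_k);
    printf("%d mma instructions per warp tile\n", n_inst); // 1 * 16 * 1 = 16
    return 0;
  }

This matches the i25 < 16 loop around the wgmma instruction above.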

Fixes #3636
@jacobhinkle (Collaborator Author)

!test

@jacobhinkle (Collaborator Author)

The bank conflict came from stmatrix scheduling, which needs to be updated. I will do that in a separate PR. For now, I've disabled the smem epilogue in the included test.
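
(For context: a shared-memory bank conflict means several lanes of a warp touch the same 32-bit bank in one transaction. A minimal illustration of the standard bank mapping, independent of nvFuser:)

  #include <cstdio>

  // Standard NVIDIA smem banking: 32 banks, 4 bytes wide.
  // Addresses that differ by a multiple of 128 bytes land in the same bank,
  // which is the usual failure mode for strided stmatrix-style stores.
  int main() {
    const int bytes_per_bank = 4, num_banks = 32;
    for (int lane = 0; lane < 4; ++lane) {
      const int addr = lane * 128; // 128B stride: every lane hits bank 0
      printf("lane %d -> bank %d\n", lane, (addr / bytes_per_bank) % num_banks);
    }
    return 0;
  }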

@jacobhinkle jacobhinkle marked this pull request as ready for review December 31, 2024 13:47

@jacobhinkle (Collaborator Author)

!test

@jacobhinkle (Collaborator Author) commented Dec 31, 2024

When I manually disable stmatrix but keep the TMA store, I still hit a bank conflict and a misaligned address in the smem read when doing the TMA store. The epilogue looks like this:

  asm volatile("wgmma.commit_group.sync.aligned;\n");
  asm volatile("wgmma.wait_group.sync.aligned %0;\n"::"n"(0LL):"memory");
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i50 = 0; i50 < 16; ++i50) {
    nvfuser_index_t i51;
    i51 = 4 * i50;
    #pragma unroll
    for(nvfuser_index_t i52 = 0; i52 < 2; ++i52) {
      nvfuser_index_t i53;
      i53 = i51 + (2 * i52);
      Array<__half, 2, 2> T6;
      #pragma unroll
      for(nvfuser_index_t i54 = 0; i54 < 2; ++i54) {
        T6[i54]
           = __float2half(T2[(i53 + i54)]);
      }
      loadGeneric<__half, 2>( &T7[(i17 + (128 * i52))],  &T6[0]);
    }
    __syncthreads();
    asm volatile("fence.proxy.async;\n");
    if (b24) {
      Hopper::cpAsyncBulkTensorTileS2G((Hopper::CpAsyncBulkTensorTileS2GIndex<2>{ ptr19, (Array<nvfuser_index_t, 2, 1>{(i20 + (8 * i50)), i21}) }), i18);
    }
    __syncthreads();
    asm volatile("cp.async.bulk.commit_group;\n");
    asm volatile("cp.async.bulk.wait_group.read %0;\n"::"n"(0LL):"memory");
  }
  asm volatile("cp.async.bulk.commit_group;\n");
  asm volatile("cp.async.bulk.wait_group.read %0;\n"::"n"(0LL):"memory");

The misaligned read happens with i20 = 1152, i50 = 0, i21 = 320, i18 = 3088. Note that we have

  threadIdx.y = 3;
  i11 = ((nvfuser_index_t)threadIdx.y) / 2; // =1
  i12 = 2048 * i11; // =2048
  i14 = ((nvfuser_index_t)threadIdx.y) % 2; // =1
  i18 = (toSmem(T7) + i12) + (16 * i14); // =toSmem(T7) + 2064

CUDA Exception: Warp Misaligned Address
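
If I read the TMA requirements correctly (bulk tensor copies want a 128-byte-aligned shared-memory base; I have not re-checked the exact PTX ISA wording), the offset arithmetic above explains the fault:

  #include <cstdio>

  // Hedged check: toSmem(T7) is presumably 128B-aligned, so the alignment of
  // the source address is determined by the computed offset.
  int main() {
    const int i12 = 2048 * (3 / 2);    // threadIdx.y / 2 -> 2048
    const int i14 = 3 % 2;             // threadIdx.y % 2 -> 1
    const int offset = i12 + 16 * i14; // 2064, as derived above
    printf("offset %% 16  = %d\n", offset % 16);  // 0: fine for 16B bulk ops
    printf("offset %% 128 = %d\n", offset % 128); // 16: breaks 128B alignment
    return 0;
  }

The reported failing value i18 = 3088 is likewise 16 mod 128, consistent with this reading.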

@jacobhinkle (Collaborator Author)

mma result before this PR:

T2_l_float[iblockIdx.y55{( ceilDiv(i1, 128) )}, iblockIdx.x53{( ceilDiv(i6, 256) )}, rS51{( ceilDiv(i0, 16) )}, ithreadIdx.y61{64}, iS58{64}, iS60{8}, rS52{16}]
 root domain : (rS6{i0}, iS7{i1}, iS8{i6})
 logical domain : (iS7{i1}, iS8{i6}, rS6{i0})
 contiguity: t t n
  Split: iS7{i1} by factor 128 -> iblockIdx.y55{( ceilDiv(i1, 128) )}, iS56{128}
  Split: iS8{i6} by factor 256 -> iblockIdx.x53{( ceilDiv(i6, 256) )}, iS54{256}
  Split: rS6{i0} by factor 16 -> rS51{( ceilDiv(i0, 16) )}, rS52{16}
  Split: iS56{128} by factor 64 -> iS57{2}, iS58{64}
  Split: iS54{256} by factor 8 -> iS59{32}, iS60{8}
  Merge: iS57{2} and iS59{32} -> ithreadIdx.y61{64}
 loop domain : (iblockIdx.y55{( ceilDiv(i1, 128) )}, iblockIdx.x53{( ceilDiv(i6, 256) )}, rS51{( ceilDiv(i0, 16) )}, ithreadIdx.y61{64}, iS58{64}, iS60{8}, rS52{16})

And after this PR:

T2_l_float[iblockIdx.y55{( ceilDiv(i1, 128) )}, iblockIdx.x53{( ceilDiv(i6, 256) )}, rS51{( ceilDiv(i0, 16) )}, ithreadIdx.y65{4}, iS59{1}, iS63{16}, iS60{64}, iS64{8}, rS52{16}]
 root domain : (rS6{i0}, iS7{i1}, iS8{i6})
 logical domain : (iS7{i1}, iS8{i6}, rS6{i0})
 contiguity: t t n
  Split: iS7{i1} by factor 128 -> iblockIdx.y55{( ceilDiv(i1, 128) )}, iS56{128}
  Split: iS8{i6} by factor 256 -> iblockIdx.x53{( ceilDiv(i6, 256) )}, iS54{256}
  Split: rS6{i0} by factor 16 -> rS51{( ceilDiv(i0, 16) )}, rS52{16}
  Split: iS56{128} by factor 64 -> iS57{2}, iS58{64}
  Split: iS54{256} by factor 128 -> iS61{2}, iS62{128}
  Merge: iS57{2} and iS61{2} -> ithreadIdx.y65{4}
  Split: iS58{64} by factor 64 -> iS59{1}, iS60{64}
  Split: iS62{128} by factor 8 -> iS63{16}, iS64{8}
 loop domain : (iblockIdx.y55{( ceilDiv(i1, 128) )}, iblockIdx.x53{( ceilDiv(i6, 256) )}, rS51{( ceilDiv(i0, 16) )}, ithreadIdx.y65{4}, iS59{1}, iS63{16}, iS60{64}, iS64{8}, rS52{16})
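
The practical difference is visible in the TIDy extents: before, the outer pieces of the instruction-tile splits were merged straight into TIDy; after, TIDy comes from the warp-tile split and the instruction tiles remain serial. A quick check of the arithmetic in the two dumps (illustrative, not nvFuser code):

  #include <cstdio>

  // cta_tile = (128, 256); warp_tile = (64, 128); macro Hopper_64_8_16.
  int main() {
    // Before: ithreadIdx.y61{64} = (128/64) * (256/8)
    printf("before: TIDy = %d\n", (128 / 64) * (256 / 8));   // 64
    // After: ithreadIdx.y65{4} = (128/64) * (256/128), leaving
    // iS59{1} * iS63{16} = 16 serial instruction tiles per warp tile.
    printf("after:  TIDy = %d, serial tiles = %d\n",
           (128 / 64) * (256 / 128), (64 / 64) * (128 / 8)); // 4, 16
    return 0;
  }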

@jacobhinkle (Collaborator Author)

Note that I can enable the smem epilogue and the test passes if I use Hopper_64_64_16 and disable stmatrix.

I think this covers the motivation for #3616
Comment on lines 47 to 49
// K dimension is present for mma_result
tv->split(-1, params_->tile_sizes.warp_tile.k);
tv->split(-1, getK(params_->mma_macro));
@jacobhinkle (Collaborator Author):

@rdspring1 is this enough or is #3616 still needed?

@rdspring1 (Collaborator) replied:

It is all that is required for scheduler changes.

// size
// Original: [..., M, N(, K)]
// We split this into warp tiles then instruction tiles
if (is_mma_result) {
@jacobhinkle (Collaborator Author):

TODO: since there is no code in common between these branches, we should split this into two separate functions.

@rdspring1 (Collaborator) left a comment:

Do we need to remove this limitation to handle all matmul parameter configurations?

CTA tile must match warp tile K dimension for Hopper matmul but found MatMulTileOptions: warp tile [64, 256, 32], CTA tile [128, 256, 64]


@rdspring1 (Collaborator) commented Jan 2, 2025

I see C++ exception with description " INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/runtime/executor.cpp":1421, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. CUDA error: CUDA_ERROR_INVALID_VALUE failed with error invalid argument with warp specialization enabled in test HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile

@jacobhinkle (Collaborator Author):

I see C++ exception with description " INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/runtime/executor.cpp":1421, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. CUDA error: CUDA_ERROR_INVALID_VALUE failed with error invalid argument with warp specialization enabled in test HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile

I just checked this by modifying the test to expect five warp groups instead of four and by adding a WarpSpecialized(ParallelType::TIDy) argument to these circularBuffer calls:

  for (TensorView* acw_smem : acw_smems_) {
    acw_smem->circularBuffer(
        params_->circular_buffer_options.smem_circular_buffer_stage,
        /*prefetch_distance=*/
        params_->circular_buffer_options.smem_circular_buffer_stage -
            params_->circular_buffer_options
                .smem_circular_buffer_prefetch_gap);
  }
  for (TensorView* bcw_smem : bcw_smems_) {
    bcw_smem->circularBuffer(
        params_->circular_buffer_options.smem_circular_buffer_stage,
        /*prefetch_distance=*/
        params_->circular_buffer_options.smem_circular_buffer_stage -
            params_->circular_buffer_options
                .smem_circular_buffer_prefetch_gap);
  }
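
For concreteness, the modified calls would look roughly like this; a sketch assuming the overload implied by the comment above, without having spelled out the exact circularBuffer signature:

  for (TensorView* bcw_smem : bcw_smems_) {
    bcw_smem->circularBuffer(
        params_->circular_buffer_options.smem_circular_buffer_stage,
        /*prefetch_distance=*/
        params_->circular_buffer_options.smem_circular_buffer_stage -
            params_->circular_buffer_options
                .smem_circular_buffer_prefetch_gap,
        /*type=*/WarpSpecialized(ParallelType::TIDy)); // added argument
  }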

The result for me is a passing test but perf drops.

@jacobhinkle (Collaborator Author):

Do we need to remove this limitation to handle all matmul parameter configurations?

CTA tile must match warp tile K dimension for Hopper matmul but found MatMulTileOptions: warp tile [64, 256, 32], CTA tile [128, 256, 64]

I might be confused here. The K dimension is treated differently from the M and N dimensions in these tile definitions. The instruction tile's K dimension is clear, and the warp tile's K dimension (I think) signifies how much data we load at a time; we can then loop to compute instructions over all of the loaded data. The CTA tile's M and N dimensions specify the tiling of the output, but what does cta_tile.k signify? This is why I was thinking we'd keep this restriction.

Note that this restriction cta_tile.k == warp_tile.k is enforced on Ampere as part of scheduleWarpTileWithReduction.
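
To make the restriction concrete, here is the loop arithmetic it implies for the new test's configuration (a sketch of the counts only, not scheduler code):

  #include <cstdio>

  // New test: K = 8192, cta_tile.k == warp_tile.k == 32, Hopper_64_64_16.
  int main() {
    const int K = 8192, cta_k = 32, inst_k = 16;
    printf("serial main-loop iterations:  %d\n", K / cta_k);      // 256
    printf("K instruction tiles per iter: %d\n", cta_k / inst_k); // 2
    return 0;
  }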

EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
}

TEST_F(HopperMatmulTest, ScheduleWithTranslation) {
@jacobhinkle (Collaborator Author):

This test is pretty much identical to the previous one, but it uses a MatmulOp instead of fusedMultiplySum. It currently fails (it passes on main) with:

C++ exception with description " INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/device_lower/pass/circular_buffer.cpp":160, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. No IfThenElse should exist yet:
IF ElectSync:
  MBarrierArriveExpectTx(T9_s[i408] view( T9 ), 4096)
  FOR i372 in iB28{16}:
    FOR i375 in iB34{2}:
      FOR i373 in iB31{4}:
        FOR i376 in iB35{2}:
          FOR i374 in iB33{8}:
            T3_s___half[iblockIdx.x24{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 128) )}, bS22{1}, iS20{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[1] ), 16) )}, bS23{256}, iS26{1}, iB28{16}, iB34{2}, iB31{4}, iB35{2}, iB33{8}] ca_pos( 5 )
               = CpAsyncBulkTensorTile( T0_g___half[iS170{( (( (( getMetaData(T0) )).logical_size ))[0] )}, iS171{( (( (( getMetaData(T0) )).logical_size ))[1] )}] )

@jacobhinkle jacobhinkle added the on hold This issue should be revisited in the future label Jan 14, 2025
@jacobhinkle (Collaborator Author)

This is on hold temporarily while I investigate decoupling math warp groups by splitting by warp tile before the TMA/MMA scheduling. That would be a different approach: it would let us schedule the entire K loop of one math group before the next group's K loop, allowing some epilogue overlap between math groups in addition to overlapping the DMA warps.

@jacobhinkle (Collaborator Author)

!test

github-actions bot commented Jan 17, 2025

PR Reviewer Guide 🔍

(Review updated until commit 9b5e73c)

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 PR contains tests
⚡ Recommended focus areas for review

Performance Impact

The changes in the transformLikeMmaOutputWithK and transformLikeMmaOutputWithoutK functions may have a significant impact on performance. It is crucial to thoroughly review and test these changes to ensure they do not introduce any regressions.

void HopperMultipleMatmulScheduler::transformLikeMmaOutputWithK(
    TensorView* tv) {
  // The input is originally block tiled so that the inner dims are the CTA tile
  // size
  //
  // We split this into warp tiles then instruction tiles
  // Original: [..., M, N, K]
  tv->split(-3, params_->tile_sizes.warp_tile.m);
  tv->split(-3, getM(params_->mma_macro));
  tv->split(-2, params_->tile_sizes.warp_tile.n);
  tv->split(-2, getN(params_->mma_macro));
  // K dimension is present for mma_result
  // We don't need to split by warp_tile.k, since we always have
  // cta_tile.k==warp_tile.k
  tv->split(-1, getK(params_->mma_macro));
  // After Split: [..., Mo, Mw, Mi, No, Nw, Ni, Kw, Ki]
  tv->reorder({
      {-8, -8}, // Mo
      {-7, -6}, // Mw
      {-6, -3}, // Mi
      {-5, -7}, // No
      {-4, -5}, // Nw
      {-3, -2}, // Ni
      {-2, -4}, // Kw
      {-1, -1}, // Ki
  });
  // After Reorder: [..., Mo, No, Mw, Nw, Kw, Mi, Ni, Ki]
  tv->merge(-8);
  // After Merge: [..., Mo * No, Mw, Nw, Kw, Mi, Ni, Ki]
  tv->axis(-7)->parallelize(ParallelType::TIDy);
  // After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Kw, Mi, Ni, Ki]
}

void HopperMultipleMatmulScheduler::transformLikeMmaOutputWithoutK(
    TensorView* tv) {
  // TODO Add constraints

  // The input is originally block tiled so that the inner dims are the CTA tile
  // size
  // Original: [..., M, N]
  // We split this into warp tiles then instruction tiles
  tv->split(-2, params_->tile_sizes.warp_tile.m);
  tv->split(-2, getM(params_->mma_macro));
  tv->split(-1, params_->tile_sizes.warp_tile.n);
  tv->split(-1, getN(params_->mma_macro));
  // After Split: [..., Mo, Mw, Mi, No, Nw, Ni]
  tv->reorder({
      {-3, -5},
      {-2, -3},
  });
  // After Reorder: [..., Mo, No, Mw, Nw, Mi, Ni]
  tv->merge(-6);
  // After Merge: [..., Mo * No, Mw, Nw, Mi, Ni]
  tv->axis(-5)->parallelize(ParallelType::TIDy);
  // After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Mi, Ni]
}
Code Duplication

The transformLikeMmaOutputWithK and transformLikeMmaOutputWithoutK functions have similar code structures. Consider refactoring the code to reduce duplication and improve maintainability.

(The bot quotes the same two functions shown above.)
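
One way the shared structure could be factored out, as a hedged sketch (illustrative only, not code from this PR; it reuses only calls already visible in the two functions above):

void HopperMultipleMatmulScheduler::transformLikeMmaOutput(
    TensorView* tv, bool has_k) {
  const int m = has_k ? -3 : -2; // position of M before splitting
  const int n = has_k ? -2 : -1; // position of N before splitting
  tv->split(m, params_->tile_sizes.warp_tile.m);
  tv->split(m, getM(params_->mma_macro));
  tv->split(n, params_->tile_sizes.warp_tile.n);
  tv->split(n, getN(params_->mma_macro));
  if (has_k) {
    tv->split(-1, getK(params_->mma_macro));
    // [..., Mo, Mw, Mi, No, Nw, Ni, Kw, Ki]
    //   -> [..., Mo, No, Mw, Nw, Kw, Mi, Ni, Ki]
    tv->reorder({{-7, -6}, {-6, -3}, {-5, -7}, {-4, -5}, {-3, -2}, {-2, -4}});
  } else {
    // [..., Mo, Mw, Mi, No, Nw, Ni] -> [..., Mo, No, Mw, Nw, Mi, Ni]
    tv->reorder({{-3, -5}, {-2, -3}});
  }
  const int outer = has_k ? -8 : -6;
  tv->merge(outer); // Mo * No
  tv->axis(outer + 1)->parallelize(ParallelType::TIDy);
}

Whether a boolean flag or two thin wrappers reads better is a judgment call; the point is only that the M/N tiling is identical in both paths.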
Test Coverage

The added tests in HopperMatmulTest may not cover all possible scenarios. Review the test cases to ensure they are comprehensive and cover all the necessary edge cases.

// This tests that we can use a small instruction tile with a medium size
// warpgroup tile and a large CTA tile.
TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  constexpr int64_t M = 2048, N = 2048, K = 8192;
  const auto dtype = DataType::Half;

  auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // K, M
  auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype); // K, N
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  auto tv2 = fusedMultiplySum(tv0, tv1, {0});

  // Reorder the accumulator as [M, N, K]
  // [K, M, N] -> [M, N, K]
  tv2->reorder({{-3, -1}});
  tv2->commitLeafToLogical();

  auto tv3 = castOp(DataType::Half, tv2);
  fusion.addOutput(tv3);

  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
  auto a_ref = at::randn({K, M, 1}, options);
  auto b_ref = at::randn({K, 1, N}, options);
  auto out_ref = at::matmul(a_ref.squeeze().t(), b_ref.squeeze()).to(at::kHalf);

  MatMulTileOptions gemm_tile;
  // Regardless of the instruction, this should result in 4 warp groups,
  // i.e. 512 threads
  gemm_tile.cta_tile = GemmTile(256, 256, 32);
  gemm_tile.warp_tile = GemmTile(128, 128, 32);

  MatmulParams mparams;
  mparams.supported_vec_size = {8, 8, 8};
  mparams.mma_macro = MmaMacro::Hopper_64_64_16;
  mparams.tile_sizes = gemm_tile;
  mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor;
  mparams.async_gmem_load_operands = true;
  mparams.circular_buffer_options.circular_buffer_smem_write = true;
  mparams.circular_buffer_options.circular_buffer_smem_read = false;
  mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
  mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
  mparams.splitk_factor = 1;
  // NOTE: disabling smem use for this test since we currently hit a bank
  // conflict.
  // TODO: enable smem epilogue once stmatrix is updated
  mparams.use_smem_epilogue = false;
  mparams.cluster_dims = {2, 1, 1};
  mparams.promote_prologue_smem_reuse = false;

  SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
      ->schedule(&fusion, &mparams);

  std::vector<c10::IValue> inputs = {a_ref, b_ref};

  KernelExecutor ke;
  ke.compile(&fusion, inputs);
  kir::Kernel* kernel = ke.compiledKernel()->kernel();
  ASSERT_TRUE(kernel != nullptr);
  EXPECT_TRUE(getBankConflictInfo(kernel).empty());
  EXPECT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(kernel));

  auto cg_outputs = ke.run(inputs);

  // Check number of launched threads matches what we expect
  EXPECT_EQ(ke.lastLaunchParams().bdimx(), 128);
  EXPECT_EQ(ke.lastLaunchParams().bdimy(), 4)
      << " expected 4 warp groups (BIDy==4) but found BIDy=="
      << ke.lastLaunchParams().bdimy();

  // Relax tolerance for larger sum due to large K
  EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
}

TEST_F(HopperMatmulTest, ScheduleWithTranslation) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  constexpr int64_t M = 2048, N = 2048, K = 8192;
  const auto dtype = DataType::Half;

  auto tv0 = makeContigConcreteTensor({-1, -1}, dtype); // M, K
  auto tv1 = makeContigConcreteTensor({-1, -1}, dtype); // K, N
  // Note tv1 has allocation domain
  // tv1->setAllocationDomain({tv1->axis(1), tv1->axis(0)}, true);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  auto tv2 = matmul(tv0, tv1);

  fusion.addOutput(tv2);

  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
  auto a_ref = at::randn({M, K}, options);
  // auto b_ref = at::randn({N, K}, options).t();
  auto b_ref = at::randn({K, N}, options);
  auto out_ref = at::matmul(a_ref, b_ref);

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(128, 256, 16);
  gemm_tile.warp_tile = GemmTile(64, 64, 16);

  MatmulParams mparams;
  mparams.supported_vec_size = {8, 8, 8};
  mparams.mma_macro = MmaMacro::Hopper_64_64_16;
  mparams.tile_sizes = gemm_tile;
  mparams.cta_order = MatmulParams::TileRasterizationOrder::RowMajor;
  mparams.async_gmem_load_operands = true;
  mparams.circular_buffer_options.circular_buffer_smem_write = true;
  mparams.circular_buffer_options.circular_buffer_smem_read = false;
  mparams.circular_buffer_options.smem_circular_buffer_stage = 3;
  mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
  mparams.splitk_factor = 1;
  mparams.use_smem_epilogue = true;
  mparams.cluster_dims = {1, 1, 1};
  mparams.promote_prologue_smem_reuse = true;

  SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
      ->schedule(&fusion, &mparams);

  std::vector<c10::IValue> inputs = {a_ref, b_ref};

  KernelExecutor ke;
  ke.compile(&fusion, inputs);
  kir::Kernel* kernel = ke.compiledKernel()->kernel();
  ASSERT_TRUE(kernel != nullptr);
  EXPECT_TRUE(getBankConflictInfo(kernel).empty());
  EXPECT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(kernel));

  auto cg_outputs = ke.run(inputs);

  // Relax tolerance for larger sum due to large K
  EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
}

} // namespace nvfuser

There is still one case that fails, which we should fix. I'll create an issue for it.
@jacobhinkle (Collaborator Author)

!test

@jacobhinkle (Collaborator Author)

!build

@jacobhinkle jacobhinkle removed the on hold This issue should be revisited in the future label Jan 29, 2025
@jacobhinkle jacobhinkle requested a review from rdspring1 January 29, 2025 21:47
@jacobhinkle (Collaborator Author)

!build

github-actions bot commented Feb 3, 2025

Review updated until commit 1d4697e

Description

  • Split Hopper MMA by warp-tile before instruction tile

  • Introduce new methods for handling MMA outputs with and without K dimension

  • Add test cases for multiple instructions per warp tile and scheduler translation

  • Update CUDA architecture range for matmul node translation tests


Changes walkthrough 📝

Relevant files:

Enhancement

csrc/device_lower/pass/index.cpp — Relax logical domain size check
  • Relax the logical domain size check for stmatrix operations (+1/-1)

csrc/scheduler/hopper_multi_matmul.cpp — Update MMA output transformations
  • Introduce transformLikeMmaOutputWithK for MMA outputs with K dimension
  • Introduce transformLikeMmaOutputWithoutK for MMA outputs without K dimension
  • Update scheduleMmaResults and scheduleEpilogue to use the new methods
  • Update scheduleSplitKSum to use transformLikeMmaOutputWithoutK
  (+62/-42)

csrc/scheduler/hopper_multi_matmul.h — Declare new MMA output transformation methods
  • Declare the new methods transformLikeMmaOutputWithK and transformLikeMmaOutputWithoutK (+6/-1)

Tests

tests/cpp/test_matmul.cpp — Add new matmul tests
  • Add test for multiple instructions per warp tile
  • Add test for scheduler translation (+138/-0)

tests/cpp/test_translate_mma.cpp — Update CUDA architecture range and add guard
  • Update CUDA architecture range for matmul node translation tests
  • Add guard for specific test case failure on Hopper (+6/-1)

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Performance Concern

The new transformations introduced in transformLikeMmaOutputWithK and transformLikeMmaOutputWithoutK should be evaluated for performance. Ensure that the performance gains from splitting by warp tile are significant and that there are no regressions.

(The bot quotes transformLikeMmaOutputWithK again, now beginning with the added check NVF_ERROR(tv->axis(-1)->isReduction(), "Inner axis should be Reduction."); the rest matches the version shown above.)
Test Coverage

The new tests added should cover a variety of scenarios and configurations to ensure the scheduler behaves correctly under different conditions. Consider adding more test cases with varying tile sizes and dimensions.

(The bot quotes the HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile and ScheduleWithTranslation tests in full; they are identical to the versions shown above.)
Code Complexity

The new methods transformLikeMmaOutputWithK and transformLikeMmaOutputWithoutK introduce additional complexity. Ensure that the code is well-documented and that the logic is clear and maintainable.

(The bot again quotes transformLikeMmaOutputWithK, identical to the version above.)

@@ -365,6 +401,7 @@ void HopperMultipleMatmulScheduler::scheduleOperands() {
    const std::vector<TensorView*>& smem_operands,
    MmaOperand operand_type) {
  blockTileTensors(smem_operands);
  parallelizeBlocks(smem_operands);
@jacobhinkle (Collaborator Author):

We might as well also parallelize these. Note that we could just call this from blockTileTensors, since we are always parallelizing outer dims for every tensor.

@rdspring1 (Collaborator)

!test

@rdspring1 (Collaborator) left a comment:

LGTM

(Two review comments on csrc/scheduler/hopper_multi_matmul.cpp were marked outdated and resolved.)
// Original: [..., Mo, No, Mi, Ni]
void HopperMultipleMatmulScheduler::transformLikeMmaOutputWithK(
    TensorView* tv) {
  // The input is originally block tiled so that the inner dims are the CTA tile

@rdspring1 (Collaborator): Do we have any conditions to check here?

@jacobhinkle (Collaborator Author): Good point. I guess we should check that the inner dim is reduction at least.

jacobhinkle and others added 3 commits February 5, 2025 16:29
Co-authored-by: Ryan Spring <[email protected]>
@jacobhinkle (Collaborator Author)

!test

@jacobhinkle (Collaborator Author)

!build

@jacobhinkle jacobhinkle merged commit b076a55 into main Feb 6, 2025
14 of 15 checks passed
@jacobhinkle jacobhinkle deleted the hopper_warptile_split branch February 6, 2025 00:36