diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp
index 8ae2e4d4632..9768645d137 100644
--- a/csrc/device_lower/pass/index.cpp
+++ b/csrc/device_lower/pass/index.cpp
@@ -2094,7 +2094,7 @@ void IndexLowering::handle(const LoadStoreOp* ldst) {
         2};
   } else if (ir_utils::isStMatrixOp(ldst)) {
     NVF_ERROR(
-        ldst->out()->as<TensorView>()->getLogicalDomain().size() == 2,
+        ldst->out()->as<TensorView>()->getLogicalDomain().size() >= 2,
         "We only support 2D inputs stmatrix");
 
     NVF_ERROR(
diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp
index 0210a946ada..54c65b372e3 100644
--- a/csrc/scheduler/hopper_multi_matmul.cpp
+++ b/csrc/scheduler/hopper_multi_matmul.cpp
@@ -29,25 +29,70 @@
 
 namespace nvfuser {
 
-void HopperMultipleMatmulScheduler::transformLikeMmaOutput(TensorView* tv) {
+void HopperMultipleMatmulScheduler::transformLikeMmaOutputWithK(
+    TensorView* tv) {
+  NVF_ERROR(tv->axis(-1)->isReduction(), "Inner axis should be Reduction.");
+  // The input is originally block tiled so that the inner dims are the CTA tile
+  // size
+  //
+  // We split this into warp tiles then instruction tiles
+  // Original: [..., M, N, K]
+  tv->split(-3, params_->tile_sizes.warp_tile.m);
+  tv->split(-3, getM(params_->mma_macro));
+  tv->split(-2, params_->tile_sizes.warp_tile.n);
+  tv->split(-2, getN(params_->mma_macro));
+  // K dimension is present for mma_result
+  // We don't need to split by warp_tile.k, since we always have
+  // cta_tile.k == warp_tile.k
+  tv->split(-1, getK(params_->mma_macro));
+  // After Split: [..., Mo, Mw, Mi, No, Nw, Ni, Kw, Ki]
+  tv->reorder({
+      {-8, -8}, // Mo
+      {-7, -6}, // Mw
+      {-6, -3}, // Mi
+      {-5, -7}, // No
+      {-4, -5}, // Nw
+      {-3, -2}, // Ni
+      {-2, -4}, // Kw
+      {-1, -1}, // Ki
+  });
+  // After Reorder: [..., Mo, No, Mw, Nw, Kw, Mi, Ni, Ki]
+  tv->merge(-8);
+  // After Merge: [..., Mo * No, Mw, Nw, Kw, Mi, Ni, Ki]
+  tv->axis(-7)->parallelize(ParallelType::TIDy);
+  // After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Kw, Mi, Ni, Ki]
+}
+
+void HopperMultipleMatmulScheduler::transformLikeMmaOutputWithoutK(
+    TensorView* tv) {
   NVF_ERROR(
       tv->domain()->loop().size() >= 4,
-      "transformLikeMmaOutput requires at least four iterDomains but ",
+      "transformLikeMmaOutputWithoutK requires at least four iterDomains but ",
       tv->toString(),
       " only has ",
       tv->domain()->loop().size(),
       ".");
+  NVF_ERROR(
+      !tv->axis(-1)->isReduction(), "Inner axis should not be Reduction.");
 
-  // Original: [..., Mo, No, Mi, Ni]
+  // The input is originally block tiled so that the inner dims are the CTA tile
+  // size
+  // Original: [..., M, N]
+  // We split this into warp tiles then instruction tiles
+  tv->split(-2, params_->tile_sizes.warp_tile.m);
   tv->split(-2, getM(params_->mma_macro));
+  tv->split(-1, params_->tile_sizes.warp_tile.n);
   tv->split(-1, getN(params_->mma_macro));
-  // After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
-  tv->reorder({{-3, -2}});
-  // After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
-  tv->merge(-4);
-  // After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
-  tv->axis(-3)->parallelize(ParallelType::TIDy);
-  // After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
+  // After Split: [..., Mo, Mw, Mi, No, Nw, Ni]
+  tv->reorder({
+      {-3, -5},
+      {-2, -3},
+  });
+  // After Reorder: [..., Mo, No, Mw, Nw, Mi, Ni]
+  tv->merge(-6);
+  // After Merge: [..., Mo * No, Mw, Nw, Mi, Ni]
+  tv->axis(-5)->parallelize(ParallelType::TIDy);
+  // After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Mi, Ni]
 }
 
 MatmulDimRole HopperMultipleMatmulScheduler::findMatmulDimRole(IterDomain* id) {
@@ -365,6 +410,7 @@ void HopperMultipleMatmulScheduler::scheduleOperands() {
       const std::vector<TensorView*>& smem_operands,
       MmaOperand operand_type) {
     blockTileTensors(smem_operands);
+    parallelizeBlocks(smem_operands);
     for (TensorView* tv : smem_operands) {
       if (params_->promote_prologue_smem_reuse) {
         tv->promoteReuse();
@@ -452,33 +498,7 @@ void HopperMultipleMatmulScheduler::scheduleMmaResults() {
       splitk_sums_.push_back(splitk_sum);
     }
 
-    // Original: [..., Mo, No, Mi, Ni, Ki]
-    mma_result->split(-3, getM(params_->mma_macro));
-    mma_result->split(-2, getN(params_->mma_macro));
-
-    // Split k dimension of warp tile only if it is larger than k dimension of
-    // mma macro. Inlining can be at incorrect position for circular buffering
-    // if a reduction iterDomain has iterDomain 1.
-    if (params_->tile_sizes.warp_tile.k > getK(params_->mma_macro)) {
-      mma_result->split(-1, getK(params_->mma_macro));
-      // After Split: [..., Mo, No, Mio, Mii, Nio, Nii, Kio, Kii]
-      mma_result->reorder({{-5, -4}});
-      // After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii, Kio, Kii]
-      mma_result->reorder({{-2, -4}});
-      // After Reorder: [..., Mo, No, Mio, Nio, Kio, Mii, Nii, Kii]
-      mma_result->merge(-6);
-      // After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
-      mma_result->axis(-5)->parallelize(ParallelType::TIDy);
-      // After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
-    } else {
-      // After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
-      mma_result->reorder({{-4, -3}});
-      // After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
-      mma_result->merge(-5);
-      // After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
-      mma_result->axis(-4)->parallelize(ParallelType::TIDy);
-      // After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
-    }
+    transformLikeMmaOutputWithK(mma_result);
 
     auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
         mma_result->getLoopDomain());
@@ -514,7 +534,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
       // op.
       blockTileTensors({d});
       parallelizeBlocks({d});
-      transformLikeMmaOutput(d);
+      transformLikeMmaOutputWithoutK(d);
 
       auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
           d->getLoopDomain());
@@ -545,8 +565,8 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
     // tile is a multiple of the macro size because stmatrix stores results from
     // wgmma to shared memory. For maximum inlining and to reduce shared memory
    // usage, the tma tile is mma_macro size.
-    const int64_t tma_m = getM(params_->mma_macro);
-    const int64_t tma_n = getN(params_->mma_macro);
+    const int64_t tma_m = params_->tile_sizes.warp_tile.m;
+    const int64_t tma_n = params_->tile_sizes.warp_tile.n;
 
     fusion_->manage("st_matrix_m_tile", stmatrix_tile_m);
     fusion_->manage("st_matrix_n_tile", stmatrix_tile_n);
@@ -594,7 +614,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
     blockTileTensors(tvs_to_schedule);
     parallelizeBlocks(tvs_to_schedule);
     for (auto tv : tvs_to_schedule) {
-      transformLikeMmaOutput(tv);
+      transformLikeMmaOutputWithoutK(tv);
     }
 
     // Should not propagate if the dc is a mma output as the mma output has
@@ -645,7 +665,7 @@ void HopperMultipleMatmulScheduler::scheduleSplitKSum() {
   for (TensorView* splitk_sum : splitk_sums_) {
     // Always use serial grid reduction for split-K sum
    splitk_sum->definition()->as<ReductionOp>()->requestSerialGridReduction();
-    transformLikeMmaOutput(splitk_sum);
+    transformLikeMmaOutputWithoutK(splitk_sum);
     auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
         splitk_sum->getLoopDomain());
     splitk_sum->setLoopDomain(s.as<IterDomain*>());
diff --git a/csrc/scheduler/hopper_multi_matmul.h b/csrc/scheduler/hopper_multi_matmul.h
index 4d91d65cbc0..380c691eb15 100644
--- a/csrc/scheduler/hopper_multi_matmul.h
+++ b/csrc/scheduler/hopper_multi_matmul.h
@@ -191,7 +191,12 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler {
   // Schedule a block-tiled TensorView like mma output.
   // Why? WGMMA has a unique output format. TensorViews after the mma-result in
   // registers must respect this format for correctness.
-  void transformLikeMmaOutput(TensorView* tv);
+  // This version is meant to be used on the mma_result, which has a Reduction
+  // K axis.
+  void transformLikeMmaOutputWithK(TensorView* tv);
+
+  // This is like the above method, but tv should not have any K dimension.
+  void transformLikeMmaOutputWithoutK(TensorView* tv);
 
  private:
   std::vector<ValGroup> canonical_dim_ordering_;
diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp
index db0d2856050..1f576dbd8e8 100644
--- a/tests/cpp/test_matmul.cpp
+++ b/tests/cpp/test_matmul.cpp
@@ -4519,4 +4519,142 @@ INSTANTIATE_TEST_SUITE_P(
       return ss.str();
     });
 
+// This tests that we can use a small instruction tile with a medium size
+// warpgroup tile and a large CTA tile.
+TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  constexpr int64_t M = 2048, N = 2048, K = 8192;
+  const auto dtype = DataType::Half;
+
+  auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // K, M
+  auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype); // K, N
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+
+  auto tv2 = fusedMultiplySum(tv0, tv1, {0});
+
+  // Reorder the accumulator as [M, N, K]
+  // [K, M, N] -> [M, N, K]
+  tv2->reorder({{-3, -1}});
+  tv2->commitLeafToLogical();
+
+  auto tv3 = castOp(DataType::Half, tv2);
+  fusion.addOutput(tv3);
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
+  auto a_ref = at::randn({K, M, 1}, options);
+  auto b_ref = at::randn({K, 1, N}, options);
+  auto out_ref = at::matmul(a_ref.squeeze().t(), b_ref.squeeze()).to(at::kHalf);
+
+  MatMulTileOptions gemm_tile;
+  // Regardless of the instruction, this should result in 4 warp groups, i.e.
+  // 512 threads
+  gemm_tile.cta_tile = GemmTile(256, 256, 32);
+  gemm_tile.warp_tile = GemmTile(128, 128, 32);
+
+  MatmulParams mparams;
+  mparams.supported_vec_size = {8, 8, 8};
+  mparams.mma_macro = MmaMacro::Hopper_64_64_16;
+  mparams.tile_sizes = gemm_tile;
+  mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor;
+  mparams.async_gmem_load_operands = true;
+  mparams.circular_buffer_options.circular_buffer_smem_write = true;
+  mparams.circular_buffer_options.circular_buffer_smem_read = false;
+  mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
+  mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
+  mparams.splitk_factor = 1;
+  // NOTE: disabling smem use for this test since we currently hit a bank
+  // conflict.
+  // TODO: enable smem epilogue once stmatrix is updated
+  mparams.use_smem_epilogue = false;
+  mparams.cluster_dims = {2, 1, 1};
+  mparams.promote_prologue_smem_reuse = false;
+
+  SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
+      ->schedule(&fusion, &mparams);
+
+  std::vector<c10::IValue> inputs = {a_ref, b_ref};
+
+  KernelExecutor ke;
+  ke.compile(&fusion, inputs);
+  kir::Kernel* kernel = ke.compiledKernel()->kernel();
+  ASSERT_TRUE(kernel != nullptr);
+  EXPECT_TRUE(getBankConflictInfo(kernel).empty());
+  EXPECT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(kernel));
+
+  auto cg_outputs = ke.run(inputs);
+
+  // Check number of launched threads matches what we expect
+  EXPECT_EQ(ke.lastLaunchParams().bdimx(), 128);
+  EXPECT_EQ(ke.lastLaunchParams().bdimy(), 4)
+      << "expected 4 warp groups (bdimy==4) but found bdimy=="
+      << ke.lastLaunchParams().bdimy();
+
+  // Relax tolerance for larger sum due to large K
+  EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
+}
+
+TEST_F(HopperMatmulTest, ScheduleWithTranslation) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  constexpr int64_t M = 2048, N = 2048, K = 8192;
+  const auto dtype = DataType::Half;
+
+  auto tv0 = makeContigConcreteTensor({-1, -1}, dtype); // M, K
+  auto tv1 = makeContigConcreteTensor({-1, -1}, dtype); // K, N
+  // Note tv1 has allocation domain
+  // tv1->setAllocationDomain({tv1->axis(1), tv1->axis(0)}, true);
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+
+  auto tv2 = matmul(tv0, tv1);
+
+  fusion.addOutput(tv2);
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
+  auto a_ref = at::randn({M, K}, options);
+  // auto b_ref = at::randn({N, K}, options).t();
+  auto b_ref = at::randn({K, N}, options);
+  auto out_ref = at::matmul(a_ref, b_ref);
+
+  MatMulTileOptions gemm_tile;
+  gemm_tile.cta_tile = GemmTile(128, 256, 16);
+  gemm_tile.warp_tile = GemmTile(64, 64, 16);
+
+  MatmulParams mparams;
+  mparams.supported_vec_size = {8, 8, 8};
+  mparams.mma_macro = MmaMacro::Hopper_64_64_16;
+  mparams.tile_sizes = gemm_tile;
+  mparams.cta_order = MatmulParams::TileRasterizationOrder::RowMajor;
+  mparams.async_gmem_load_operands = true;
+  mparams.circular_buffer_options.circular_buffer_smem_write = true;
+  mparams.circular_buffer_options.circular_buffer_smem_read = false;
+  mparams.circular_buffer_options.smem_circular_buffer_stage = 3;
+  mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
+  mparams.splitk_factor = 1;
+  mparams.use_smem_epilogue = true;
+  mparams.cluster_dims = {1, 1, 1};
+  mparams.promote_prologue_smem_reuse = true;
+
+  SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
+      ->schedule(&fusion, &mparams);
+
+  std::vector<c10::IValue> inputs = {a_ref, b_ref};
+
+  KernelExecutor ke;
+  ke.compile(&fusion, inputs);
+  kir::Kernel* kernel = ke.compiledKernel()->kernel();
+  ASSERT_TRUE(kernel != nullptr);
+  EXPECT_TRUE(getBankConflictInfo(kernel).empty());
+  EXPECT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(kernel));
+
+  auto cg_outputs = ke.run(inputs);
+
+  // Relax tolerance for larger sum due to large K
+  EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
+}
+
 } // namespace nvfuser
diff --git a/tests/cpp/test_translate_mma.cpp b/tests/cpp/test_translate_mma.cpp
index 0128b629b76..ef94ca299d1 100644
--- a/tests/cpp/test_translate_mma.cpp
+++ b/tests/cpp/test_translate_mma.cpp
@@ -315,7 +315,7 @@ using MatmulNodeTranslationTest =
 // Test that a simple matmul op fusion is picked up by the appropriate scheduler
 // and the translation to MmaOp is performed properly.
 TEST_P(MatmulNodeTranslationTest, AutomaticSchedulerMatmulNode) {
-  NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);
+  NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 10, 0);
   const int64_t A_dim = std::get<0>(GetParam());
   const int64_t B_dim = std::get<1>(GetParam());
   const bool enable_fusion = std::get<2>(GetParam());
@@ -323,6 +323,11 @@ TEST_P(MatmulNodeTranslationTest, AutomaticSchedulerMatmulNode) {
   const bool expect_segmented = std::get<4>(GetParam());
   const SchedulerType expected_heuristic = std::get<5>(GetParam());
 
+  if (A_dim == 3 && B_dim == 2) {
+    // TODO: Fix the failure at checkConcreteStaticDim on Hopper in this case
+    NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);
+  }
+
   // CombineMulSumAsMmaTest disabled MatmulExprEval, but we need it enabled
   DisableOptionsGuard dog;
   DisableOptionsGuard::getCurOptions().unset(DisableOption::MatmulExprEval);
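
Note (not part of the patch): below is a minimal standalone C++ sketch of the warp-tile decomposition that the new transformLikeMmaOutputWithK performs, plugging in the tile sizes from the HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile test above. The Tile struct and the printed summary are illustrative assumptions, not nvfuser APIs; only the arithmetic mirrors the scheduler's splits.

#include <cassert>
#include <cstdio>

// Illustrative tile sizes, copied from the test above (not an nvfuser type).
struct Tile {
  int m;
  int n;
  int k;
};

int main() {
  const Tile cta{256, 256, 32}; // gemm_tile.cta_tile
  const Tile warp{128, 128, 32}; // gemm_tile.warp_tile
  const Tile inst{64, 64, 16}; // MmaMacro::Hopper_64_64_16

  // transformLikeMmaOutputWithK splits [M, N, K] into
  // [Mo, Mw, Mi, No, Nw, Ni, Kw, Ki]: CTA tile -> warp tiles -> macro tiles.
  const int Mo = cta.m / warp.m; // 2
  const int Mw = warp.m / inst.m; // 2
  const int Mi = inst.m; // 64
  const int No = cta.n / warp.n; // 2
  const int Nw = warp.n / inst.n; // 2
  const int Ni = inst.n; // 64
  const int Kw = warp.k / inst.k; // 2 (cta_tile.k == warp_tile.k)
  const int Ki = inst.k; // 16

  // Mo * No is merged and parallelized on TIDy: one warp group per warp tile.
  const int warp_groups = Mo * No;
  assert(warp_groups == 4); // matches EXPECT_EQ(bdimy, 4) in the test

  // Within a warp tile, each warp group runs an Mw x Nw x Kw grid of wgmma
  // instruction tiles per CTA-tile K step.
  std::printf(
      "warp groups (TIDy) = %d, macro tiles per warp tile = %d x %d x %d, "
      "instruction tile = %d x %d x %d\n",
      warp_groups, Mw, Nw, Kw, Mi, Ni, Ki);
  return 0;
}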