diff --git a/example/ck_tile/40_streamk_gemm/README.md b/example/ck_tile/40_streamk_gemm/README.md
index fe9eb0c4f8..0272b1fe97 100644
--- a/example/ck_tile/40_streamk_gemm/README.md
+++ b/example/ck_tile/40_streamk_gemm/README.md
@@ -22,8 +22,8 @@ args:
 -a_layout             tensor A data layout (default: R)
 -b_layout             tensor B data layout (default: C)
 -c_layout             tensor C data layout (default: R)
--num_sk_blocks        number of Stream-K blocks. -1: chosen by algorithm, or user selected (default:-1)
 -reduction_strategy   strategy for storing results in C tensor. atomic/reduction (default:atomic)
+-persistent_dp        persistence strategy for the data-parallel section. Set to 0 for non-persistent or to 1 for persistent (default:0)
 -stride_a             tensor A stride (default:0)
 -stride_b             tensor B stride (default:0)
 -stride_c             tensor C stride (default:0)
diff --git a/example/ck_tile/40_streamk_gemm/gemm_utils.hpp b/example/ck_tile/40_streamk_gemm/gemm_utils.hpp
index abcca7eaec..69095ca3d7 100644
--- a/example/ck_tile/40_streamk_gemm/gemm_utils.hpp
+++ b/example/ck_tile/40_streamk_gemm/gemm_utils.hpp
@@ -18,7 +18,6 @@ struct GemmConfigBase
     static constexpr bool TransposeC            = false;
     static constexpr bool UseStructuredSparsity = false;

-    static constexpr bool Persistent = false;
     static constexpr int kBlockPerCu = 1;
     static constexpr auto Scheduler  = ck_tile::GemmPipelineScheduler::Intrawave;
@@ -27,12 +26,12 @@ struct GemmConfigBase
     static constexpr bool DoubleSmemBuffer = false;
 };

-template <typename PrecType>
+template <typename PrecType, bool Persistent_>
 struct GemmConfigMemoryInterwave : public GemmConfigBase
 {
-    static constexpr ck_tile::index_t M_Tile = 128;
-    static constexpr ck_tile::index_t N_Tile = 128;
-    static constexpr ck_tile::index_t K_Tile = 32;
+    static constexpr ck_tile::index_t M_Tile = 256;
+    static constexpr ck_tile::index_t N_Tile = 256;
+    static constexpr ck_tile::index_t K_Tile = 16;

     static constexpr ck_tile::index_t M_Warp = 2;
     static constexpr ck_tile::index_t N_Warp = 2;
@@ -42,7 +41,8 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase
     static constexpr ck_tile::index_t N_Warp_Tile = 32;
     static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;

-    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
+    static constexpr bool Persistent = Persistent_;
+    static constexpr auto Scheduler  = ck_tile::GemmPipelineScheduler::Intrawave;
 };

 template
@@ -96,12 +96,12 @@ auto create_args(int argc, char* argv[])
         .insert("a_layout", "R", "A tensor data layout - Row by default")
         .insert("b_layout", "C", "B tensor data layout - Column by default")
         .insert("c_layout", "R", "C tensor data layout - Row by default")
-        .insert("num_sk_blocks",
-                "-1",
-                "number of Stream-K blocks. -1: chosen by algorithm, or user selected")
         .insert("reduction_strategy",
                 "atomic",
                 "strategy for storing results in C tensor - atomic/reduction")
+        .insert("persistent_dp",
+                "0",
+                "0: non-persistent data-parallel section, 1: fully persistent kernel")
         .insert("stride_a", "0", "Tensor A stride")
         .insert("stride_b", "0", "Tensor B stride")
         .insert("stride_c", "0", "Tensor C stride")
diff --git a/example/ck_tile/40_streamk_gemm/run_gemm_example.inc b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc
index 6dd054ee11..17182d87dc 100644
--- a/example/ck_tile/40_streamk_gemm/run_gemm_example.inc
+++ b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc
@@ -69,20 +69,18 @@ invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
             int n_warmup,
             int n_repeat,
             bool flush_cache,
-            ck_tile::StreamKReductionStrategy reduction_strategy,
-            uint32_t num_sk_blocks)
+            ck_tile::StreamKReductionStrategy reduction_strategy)
 {
-    ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
-                                  b_k_n_dev_buf.GetDeviceBuffer(),
-                                  c_m_n_dev_buf.GetDeviceBuffer(),
-                                  M,
-                                  N,
-                                  K,
-                                  stride_A,
-                                  stride_B,
-                                  stride_C,
-                                  reduction_strategy,
-                                  num_sk_blocks};
+    ck_tile::reboot::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
+                                          b_k_n_dev_buf.GetDeviceBuffer(),
+                                          c_m_n_dev_buf.GetDeviceBuffer(),
+                                          M,
+                                          N,
+                                          K,
+                                          stride_A,
+                                          stride_B,
+                                          stride_C,
+                                          reduction_strategy};

     std::tuple<float, ck_tile::index_t> ave_time_and_batch;
@@ -197,7 +195,6 @@ int run_gemm_example_with_layouts(int argc,
     ck_tile::StreamKReductionStrategy reduction_strategy =
         get_reduction_strategy_value(arg_parser.get_str("reduction_strategy"));

-    uint32_t num_sk_blocks = static_cast<uint32_t>(arg_parser.get_int("num_sk_blocks"));

     stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
     stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
@@ -261,8 +258,7 @@ int run_gemm_example_with_layouts(int argc,
                     n_warmup,
                     n_repeat,
                     flush_cache,
-                    reduction_strategy,
-                    num_sk_blocks);
+                    reduction_strategy);

     c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
@@ -279,10 +275,10 @@ int run_gemm_example_with_layouts(int argc,
              << " B_Type=" << DataTypeTraits<BDataType>::name
              << " C_Type=" << DataTypeTraits<CDataType>::name
              << " reduction_strategy=" << arg_parser.get_str("reduction_strategy") << " "
-              << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
-              << std::endl;
+              << " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << ave_time
+              << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;

-    bool pass = true;
+    bool pass = false;

     // Memory on host to store gpu reference result
     ck_tile::HostTensor<CDataType> c_m_n_ref(
diff --git a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp
index 8ec409023d..e04cb00379 100644
--- a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp
+++ b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp
@@ -2,7 +2,6 @@
 // SPDX-License-Identifier: MIT

 #include "gemm_utils.hpp"
-#include "run_gemm_example.inc"
 #include "ck_tile/ops/common.hpp"

 template
-std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
+std::tuple<float, ck_tile::index_t> gemm(const ck_tile::reboot::StreamKHostArgs& args,
                                          const ck_tile::stream_config& s)
-
 {
     using GemmShape = ck_tile::TileGemmShape<
         ck_tile::sequence,
@@ -29,7 +27,8 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
         GemmConfig::PermuteA,
         GemmConfig::PermuteB>;

-    using TilePartitioner = ck_tile::StreamKTilePartitioner;
+    using TilePartitioner =
+        ck_tile::StreamKTilePartitioner_v2;

     using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits
@@ ... @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
             memory_operation.value,
             GemmConfig::NumWaveGroups>>;

-    using Kernel = ck_tile::StreamKKernel;
+    using Kernel = ck_tile::reboot::StreamKKernel;

-    auto kargs = Kernel::MakeKernelArgs(args);
+    auto kargs                = Kernel::MakeKernelArgs(args);
+    const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);
+    ck_tile::DeviceMem workspace_data(workspace_size);
+    workspace_data.SetZero();
+    kargs.workspace_ptr = workspace_data.GetDeviceBuffer();

     dim3 grids  = Kernel::GridSize(kargs.tile_partitioner);
     dim3 blocks = Kernel::BlockSize();
@@ -101,28 +104,28 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
                       << std::endl;
         }

-        // Function to clear the output C tensor results after each repetition of the kernel
-        auto clear_gemm_output = [&]() {
+        auto reset_data_buffers = [&]() {
             if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
+            {
+                // Clear the output C tensor results after each repetition of the kernel
                 hipGetErrorString(hipMemsetAsync(
                     args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
+            }
+            else if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
+            {
+                // Reset sk flags to zero before each repetition of the kernel
+                workspace_data.SetZero();
+            }
         };

-        std::function<void()> preprocess = clear_gemm_output;
+        std::function<void()> preprocess = reset_data_buffers;

         float ave_time = ck_tile::launch_kernel_time_mask(
             s,
             preprocess,
             ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs));

-        ck_tile::index_t num_wgs_per_tile = ck_tile::estimate_num_wgs_per_tile(
-            kargs.tile_partitioner.sk_num_blocks,
-            // k_iters_per_big_block could be 1, which indicates that all Stream-K workgroups are
-            // big and each does one iteration. Thus, we ensure the value passed in is at least 1 to
-            // avoid division by zero errors.
-            ck_tile::max(kargs.tile_partitioner.k_iters_per_big_block - 1, 1u),
-            kargs.tile_partitioner.k_iters_per_tile.get());
-
+        ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile();
         return std::tuple{ave_time, num_wgs_per_tile};
     };
@@ -145,6 +148,8 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
     }
 }

+#include "run_gemm_example.inc"
+
 template
 int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
 {
@@ -164,7 +169,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
     return 0;
 }

-template
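
For readers of this patch: `Persistent` moves from a `GemmConfigBase` member to the `Persistent_` template parameter of `GemmConfigMemoryInterwave`, so the runtime `persistent_dp` flag must select a compile-time instantiation. Below is a minimal, self-contained sketch of that dispatch pattern; `Config`, `run_with`, and `run` are illustrative stand-ins, not code from this patch:

```cpp
#include <iostream>

// Illustrative stand-in for GemmConfigMemoryInterwave<PrecType, Persistent_>;
// only the runtime-to-compile-time dispatch pattern mirrors the patch.
template <bool Persistent_>
struct Config
{
    static constexpr bool Persistent = Persistent_;
};

// Stand-in for the example's launcher, fully specialized on the config.
template <typename Cfg>
int run_with()
{
    std::cout << "persistent kernel: " << (Cfg::Persistent ? "yes" : "no") << '\n';
    return 0;
}

int run(int persistent_dp)
{
    // Branch once on the runtime CLI value; everything downstream of the
    // branch is specialized at compile time (persistent_dp -> Persistent_).
    return persistent_dp != 0 ? run_with<Config<true>>() : run_with<Config<false>>();
}

int main() { return run(1); }
```

On the command line this corresponds to combining the existing and new flags, e.g. `-reduction_strategy reduction -persistent_dp 1`.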
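Similarly, `estimate_num_wgs_per_tile` now lives behind the v2 tile partitioner instead of being a free function fed with `sk_num_blocks`, `k_iters_per_big_block`, and `k_iters_per_tile`. The standalone sketch below illustrates the quantity it reports (the average number of Stream-K workgroups cooperating on one C tile); the arithmetic is an assumption reconstructed from the removed call site, not the ck_tile implementation:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

// Average number of Stream-K workgroups touching a single C tile:
// total K iterations covered by all Stream-K blocks, spread over the
// tiles those iterations complete. Purely illustrative arithmetic.
uint32_t estimate_num_wgs_per_tile(uint32_t sk_num_blocks,
                                   uint32_t k_iters_per_block,
                                   uint32_t k_iters_per_tile)
{
    const uint32_t total_iters = sk_num_blocks * k_iters_per_block;
    const uint32_t sk_tiles    = std::max(total_iters / k_iters_per_tile, 1u);
    return (sk_num_blocks + sk_tiles - 1) / sk_tiles; // ceil(workgroups / tiles)
}

int main()
{
    // e.g. 304 workgroups, 8 K-iterations each, 32 K-iterations per tile:
    std::cout << estimate_num_wgs_per_tile(304, 8, 32) << '\n'; // prints 4
}
```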