Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion example/ck_tile/40_streamk_gemm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ args:
-a_layout tensor A data layout (default: R)
-b_layout tensor B data layout (default: C)
-c_layout tensor C data layout (default: R)
-num_sk_blocks number of Stream-K blocks. -1: chosen by algorithm, or user selected (default:-1)
-reduction_strategy strategy for storing results in C tensor. atomic/reduction (default:atomic)
-persistent_dp persistent strategy for data-parallel section. Set to 0 for non-persistent or to 1 for persistent. (default:0)
-stride_a tensor A stride (default:0)
-stride_b tensor B stride (default:0)
-stride_c tensor C stride (default:0)
Expand Down
18 changes: 9 additions & 9 deletions example/ck_tile/40_streamk_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ struct GemmConfigBase

static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr bool Persistent = false;

static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
Expand All @@ -27,12 +26,12 @@ struct GemmConfigBase
static constexpr bool DoubleSmemBuffer = false;
};

template <typename PrecType>
template <typename PrecType, bool Persistent_>
struct GemmConfigMemoryInterwave : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 32;
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 16;

static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
Expand All @@ -42,7 +41,8 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;

static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr bool Persistent = Persistent_;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};

template <typename ADataType_, typename BDataType_ = ADataType_, typename CDataType_ = ADataType_>
Expand Down Expand Up @@ -96,12 +96,12 @@ auto create_args(int argc, char* argv[])
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Column by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("num_sk_blocks",
"-1",
"number of Stream-K blocks. -1: chosen by algorithm, or user selected")
.insert("reduction_strategy",
"atomic",
"strategy for storing results in C tensor - atomic/reduction")
.insert("persistent_dp",
"0",
"0. Non-persistent data-parallel section, 1 Fully persistent kernel.")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
Expand Down
34 changes: 15 additions & 19 deletions example/ck_tile/40_streamk_gemm/run_gemm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -69,20 +69,18 @@ invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup,
int n_repeat,
bool flush_cache,
ck_tile::StreamKReductionStrategy reduction_strategy,
uint32_t num_sk_blocks)
ck_tile::StreamKReductionStrategy reduction_strategy)
{
ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
M,
N,
K,
stride_A,
stride_B,
stride_C,
reduction_strategy,
num_sk_blocks};
ck_tile::reboot::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
M,
N,
K,
stride_A,
stride_B,
stride_C,
reduction_strategy};

std::tuple<float, ck_tile::index_t> ave_time_and_batch;

Expand Down Expand Up @@ -197,7 +195,6 @@ int run_gemm_example_with_layouts(int argc,

ck_tile::StreamKReductionStrategy reduction_strategy =
get_reduction_strategy_value(arg_parser.get_str("reduction_strategy"));
uint32_t num_sk_blocks = static_cast<uint32_t>(arg_parser.get_int("num_sk_blocks"));

stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
Expand Down Expand Up @@ -261,8 +258,7 @@ int run_gemm_example_with_layouts(int argc,
n_warmup,
n_repeat,
flush_cache,
reduction_strategy,
num_sk_blocks);
reduction_strategy);

c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());

Expand All @@ -279,10 +275,10 @@ int run_gemm_example_with_layouts(int argc,
<< " B_Type=" << DataTypeTraits<BDataType>::name
<< " C_Type=" << DataTypeTraits<CDataType>::name
<< " reduction_strategy=" << arg_parser.get_str("reduction_strategy") << " "
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
<< " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << ave_time
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;

bool pass = true;
bool pass = false;

// Memory on host to store gpu reference result
ck_tile::HostTensor<CDataType> c_m_n_ref(
Expand Down
90 changes: 64 additions & 26 deletions example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// SPDX-License-Identifier: MIT

#include "gemm_utils.hpp"
#include "run_gemm_example.inc"
#include "ck_tile/ops/common.hpp"

template <typename GemmConfig,
Expand All @@ -17,9 +16,8 @@ template <typename GemmConfig,
typename ELayout,
typename CDEElementWise,
ck_tile::StreamKReductionStrategy ReductionStrategy>
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::reboot::StreamKHostArgs& args,
const ck_tile::stream_config& s)

{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
Expand All @@ -29,7 +27,8 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
GemmConfig::PermuteA,
GemmConfig::PermuteB>;

using TilePartitioner = ck_tile::StreamKTilePartitioner<GemmShape, ReductionStrategy>;
using TilePartitioner =
ck_tile::StreamKTilePartitioner_v2<GemmShape, ReductionStrategy, GemmConfig::Persistent>;

using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
Expand Down Expand Up @@ -78,9 +77,13 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
memory_operation.value,
GemmConfig::NumWaveGroups>>;

using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
using Kernel = ck_tile::reboot::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;

auto kargs = Kernel::MakeKernelArgs(args);
auto kargs = Kernel::MakeKernelArgs(args);
const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);
ck_tile::DeviceMem workspace_data(workspace_size);
workspace_data.SetZero();
kargs.workspace_ptr = workspace_data.GetDeviceBuffer();

dim3 grids = Kernel::GridSize(kargs.tile_partitioner);
dim3 blocks = Kernel::BlockSize();
Expand All @@ -101,28 +104,28 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
<< std::endl;
}

// Function to clear the output C tensor results after each repetition of the kernel
auto clear_gemm_output = [&]() {
// Preprocess hook handed to launch_kernel_time_mask below: resets the buffer that belongs to
// the selected (compile-time) ReductionStrategy. Everything is captured by reference, so the
// lambda must not outlive args, s, or workspace_data.
auto reset_data_buffers = [&]() {
if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
{
// Zero the M x N output tensor (args.e_ptr) with an async memset on stream s.stream_id_.
// NOTE(review): presumably needed because the atomic strategy accumulates into C, so
// stale results from a previous launch would be added to -- confirm against the kernel.
// The hipError_t is only passed to hipGetErrorString and the returned string is
// discarded, so a failing memset is not reported from here.
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
}
else if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
{
// Reset sk flags to zero before each repetition of the kernel
// (the whole workspace buffer allocated above is zeroed).
workspace_data.SetZero();
}
};

std::function<void()> preprocess = clear_gemm_output;
std::function<void()> preprocess = reset_data_buffers;

float ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));

ck_tile::index_t num_wgs_per_tile = ck_tile::estimate_num_wgs_per_tile<ReductionStrategy>(
kargs.tile_partitioner.sk_num_blocks,
// k_iters_per_big_block could be 1, which indicates that all Stream-K workgroups are
// big and each does one iteration. Thus, we ensure the value passed in is at least 1 to
// avoid division by zero errors.
ck_tile::max(kargs.tile_partitioner.k_iters_per_big_block - 1, 1u),
kargs.tile_partitioner.k_iters_per_tile.get());

ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile();
return std::tuple{ave_time, num_wgs_per_tile};
};

Expand All @@ -145,6 +148,8 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
}
}

#include "run_gemm_example.inc"

template <typename GemmConfig, typename TypeConfig>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
Expand All @@ -164,7 +169,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
return 0;
}

template <template <typename PreType> typename GemmConfig>
template <template <typename PreType, bool Persistent_> typename GemmConfig>
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
Expand All @@ -174,30 +179,63 @@ int run_gemm_example(int argc, char* argv[])
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
auto persistent_dp = arg_parser.get_bool("persistent_dp");

if(data_type == "bf16")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf16_t>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, TypeConfig>(
a_layout, b_layout, argc, argv);
if(persistent_dp)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, true>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
else
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, false>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
}
else if(data_type == "fp16")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::half_t>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, TypeConfig>(
a_layout, b_layout, argc, argv);
if(persistent_dp)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, true>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
else
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, false>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
}
else if(data_type == "fp8")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig>(
a_layout, b_layout, argc, argv);
if(persistent_dp)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t, true>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
else
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t, false>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
}
else if(data_type == "bf8")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig>(
a_layout, b_layout, argc, argv);
if(persistent_dp)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t, true>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
else
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t, false>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
}
else
{
Expand Down
4 changes: 4 additions & 0 deletions include/ck_tile/host/kernel_launch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ CK_TILE_HOST double timing_loop_impl(TimerType timer,
{
for(int i = 0; i < s.cold_niters_; i++)
{
if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
{
preprocess();
}
callables_func();
}
// Only profile preprocess if it's provided
Expand Down
Loading