Skip to content

Commit 22e8ce4

Browse files
committed
[CK TILE STREAMK] Add 'dp_persistent' and 'reduction_strategy' in output of CK TILE STREAMK
1 parent 7363096 commit 22e8ce4

File tree

3 files changed

+24
-7
lines changed

3 files changed

+24
-7
lines changed

tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ struct PerformanceResult
102102
struct KernelInstance
103103
{
104104
std::string name_;
105+
std::string dp_persistent_;
106+
std::string reduction_strategy_;
105107
GemmProblem problem_;
106108
PerformanceResult perf_result_;
107109

@@ -114,6 +116,8 @@ struct KernelInstance
114116
{
115117
os << "{\n"
116118
<< " \"name\": \"" << obj.name_ << "\",\n"
119+
<< " \"dp_persistent\": \"" << obj.dp_persistent_ << "\",\n"
120+
<< " \"reduction_strategy\": \"" << obj.reduction_strategy_ << "\",\n"
117121
<< " \"problem\": " << obj.problem_ << ",\n"
118122
<< " \"perf_result\": " << obj.perf_result_ << "\n"
119123
<< "}";

tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,10 @@ int main(int argc, char* argv[])
160160
{
161161
auto [result, parser] = create_args(argc, argv);
162162
if(!result)
163+
{
164+
parser.print();
163165
return EXIT_FAILURE;
166+
}
164167

165168
benchmark_gemm_single(parser);
166169
return EXIT_SUCCESS;

tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,15 @@ class GemmProfiler
164164
const std::tuple<std::string, float>& kernel_run_result)
165165
{
166166
auto [name, avg_time] = kernel_run_result;
167+
auto dp_persistent =
168+
SelectedKernel::UsePersistentKernel ? "PersistentKernel" : "NonPersistentKernel";
169+
auto reduction_strategy =
170+
SelectedKernel::reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic
171+
? "Atomic"
172+
: "Reduction";
167173

168-
KernelInstance kernel_instance{name, gemm_problem, {-1.0f, -1.0f, -1.0f}};
174+
KernelInstance kernel_instance{
175+
name, dp_persistent, reduction_strategy, gemm_problem, {-1.0f, -1.0f, -1.0f}};
169176

170177
// compute performance metric
171178
std::size_t flop = std::size_t(2) * gemm_problem.m_ * gemm_problem.n_ * gemm_problem.k_;
@@ -244,21 +251,24 @@ class GemmProfiler
244251
file << "rocm_version,device_name,"
245252
<< "split_k,m,n,k,stride_a,stride_b,stride_c,"
246253
<< "dtype_a,dtype_b,dtype_acc,dtype_c," << "layout_a,layout_b,layout_c,"
247-
<< "structured_sparsity," << "name,"
248-
<< "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n";
254+
<< "structured_sparsity," << "dp_persistent," << "reduction_strategy,"
255+
<< "name," << "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n";
249256
}
250257

251-
const auto& problem = kernel_instance.problem_;
252-
const auto& name = kernel_instance.name_;
253-
const auto& perf = kernel_instance.perf_result_;
258+
const auto& problem = kernel_instance.problem_;
259+
const auto& name = kernel_instance.name_;
260+
const auto& dp_persistent = kernel_instance.dp_persistent_;
261+
const auto& reduction_strategy = kernel_instance.reduction_strategy_;
262+
const auto& perf = kernel_instance.perf_result_;
254263

255264
file << get_rocm_version() << "," << ck_tile::get_device_name() << ","
256265
<< problem.split_k_ << "," << problem.m_ << "," << problem.n_ << ","
257266
<< problem.k_ << "," << problem.stride_a_ << "," << problem.stride_b_ << ","
258267
<< problem.stride_c_ << "," << problem.dtype_a_ << "," << problem.dtype_b_
259268
<< "," << problem.dtype_acc_ << "," << problem.dtype_c_ << ","
260269
<< problem.layout_a_ << "," << problem.layout_b_ << "," << problem.layout_c_
261-
<< "," << problem.structured_sparsity_ << "," << name << "," << std::fixed
270+
<< "," << problem.structured_sparsity_ << "," << dp_persistent << ","
271+
<< reduction_strategy << "," << name << "," << std::fixed
262272
<< std::setprecision(4) << perf.latency_ << "," << std::fixed
263273
<< std::setprecision(4) << perf.tflops_ << "," << std::fixed
264274
<< std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric)

0 commit comments

Comments
 (0)