Skip to content

[CK_TILE] Support for elementwise kernel #2246

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 32 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
44dcdad
Elementwise kernel implementation
Apr 30, 2025
4d41cc2
Elementwise with generalized nDims
amd-yashagar May 12, 2025
b2de0b1
Adding the n-ary input tensor feature
May 19, 2025
685f915
Generalize dimensions on top of inputs
amd-yashagar May 19, 2025
8645970
Add TFLOPS + remove std usage for tuples
May 19, 2025
984198c
1D basecase optimization
msaffari-amd May 20, 2025
e149526
Cleanup code + refactoring to a common interface
May 20, 2025
f0f1c5a
Generalize to unary and add an example
SamiAario-AMD May 21, 2025
55fe333
Cleanup, refactoring and commenting
May 27, 2025
5b225a1
Suggestions for LWPCK-3170: elementwise kernel improvements
AviralGoelAMD May 27, 2025
bc8e9ae
Clang-format: remod.py
amd-yashagar May 28, 2025
b885fb2
Replace InputTensorType with XDataType as the type of input_tensors
SamiAario-AMD May 28, 2025
8c5d714
Add Tuple::apply and use it in ElementWiseKernel::operator to call op…
SamiAario-AMD May 28, 2025
4145fbf
Move examples to folder 19_elementwise
SamiAario-AMD May 30, 2025
bcccb72
Add missing copyright headers and fix some existing ones
SamiAario-AMD May 30, 2025
05bbba7
Replace an assert with throw std::runtime_error in elementwise example
SamiAario-AMD May 30, 2025
2c46907
Avoid reading the output by using make_static_distributed_tensor for …
SamiAario-AMD Jun 2, 2025
abd0d35
Removed two unused includes
SamiAario-AMD Jun 2, 2025
876b057
No need to move windows to the next block when each workgroup process…
SamiAario-AMD Jun 2, 2025
64a7d3f
Only copy input tensors to the device
SamiAario-AMD Jun 2, 2025
a33a41e
Use get_warp_size to obtain warp size, and use ceiling division for g…
SamiAario-AMD Jun 2, 2025
b52400c
Adding output strides to the kernel, transposition example and update…
Jun 2, 2025
891f701
Changes made by remod.py
SamiAario-AMD Jun 2, 2025
86519ac
Use default template parameter values for memory operation and cohere…
SamiAario-AMD Jun 2, 2025
fc1ff7f
Move binary operations to include/ck_tile/ops/elementwise/binary_elem…
SamiAario-AMD Jun 2, 2025
1046ff9
Reuse generic reference binary/unary operation in examples + refactor…
Jun 3, 2025
821019d
Fix comments in elementwise_example.cpp
SamiAario-AMD Jun 3, 2025
96d1a6b
Simplify CMakeLists.txt and remove the unused variables this uncovers
SamiAario-AMD Jun 4, 2025
99eac8e
Rename a file and fix some copyright statements
SamiAario-AMD Jun 4, 2025
6cbf189
Changes made by script/clang-format-overwrite.sh
SamiAario-AMD Jun 4, 2025
2a30295
Add basic unit test for ElementWiseKernel
Jun 4, 2025
c1c7af8
Merge branch 'develop' into LWPCK-3170-elementwise-kernel-general
SamiAario-AMD Jun 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added logit soft-capping support for fMHA forward kernels.
* Added benchmarking support for tile engine GEMM.
* Added rotating buffer feature for CK_Tile GEMM.
* Added support for elementwise kernel.

### Optimized

Expand Down
10 changes: 10 additions & 0 deletions example/ck_tile/19_elementwise/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Build targets for the CK Tile elementwise examples (example/ck_tile/19_elementwise).
# Each target is a standalone host executable that launches the ElementWiseKernel.

# Elementwise example targets 2D inputs
# Binary (two-input) elementwise add over a 2D row-major tensor.
set(TARGET_NAME_2D_INPUT tile_example_elementwise)
add_executable(${TARGET_NAME_2D_INPUT} elementwise_example.cpp)

# Elementwise unary example targets 2D inputs
# Unary (single-input) elementwise operation over a 2D row-major tensor.
set(TARGET_NAME_2D_INPUT_UNARY tile_example_elementwise_unary)
add_executable(${TARGET_NAME_2D_INPUT_UNARY} elementwise_example_unary.cpp)

# Transposition expressed as an elementwise PassThrough with permuted output strides.
set(TARGET_NAME_2D_INPUT_TRANSPOSE tile_example_elementwise_transpose)
add_executable(${TARGET_NAME_2D_INPUT_TRANSPOSE} elementwise_example_transpose.cpp)
211 changes: 211 additions & 0 deletions example/ck_tile/19_elementwise/elementwise_example.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp"

// Build the command-line parser for this example and parse argv.
// Returns (parse_succeeded, parser); the parser carries the option values.
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;

    // Register every option with its default value and help text.
    arg_parser.insert("m", "1024", "m dimension");
    arg_parser.insert("n", "1024", "n dimension");
    arg_parser.insert("stride", "-1", "stride per row, if -1 then equal to n");
    arg_parser.insert("v", "1", "cpu validation or not");
    arg_parser.insert("prec", "fp16", "precision");
    arg_parser.insert("warmup", "10", "cold iter");
    arg_parser.insert("repeat", "50", "hot iter");

    const bool parsed_ok = arg_parser.parse(argc, argv);
    return std::make_tuple(parsed_ok, arg_parser);
}

// Run the binary elementwise-add example for one data type.
// Reads problem sizes from the parser, launches the GPU kernel, optionally
// validates against a CPU reference, and returns true on success.
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
    ck_tile::index_t M      = arg_parser.get_int("m");
    ck_tile::index_t N      = arg_parser.get_int("n");
    ck_tile::index_t stride = arg_parser.get_int("stride");

    // If stride is negative (default -1), set it to N, assuming a dense row-major layout.
    if(stride < 0)
        stride = N;
    int do_validation = arg_parser.get_int("v");
    int warmup        = arg_parser.get_int("warmup");
    int repeat        = arg_parser.get_int("repeat");

    if(stride < N)
    {
        throw std::runtime_error("stride must be >= N");
    }

    // Define type aliases for clarity.
    // XDataType: Data type of the input tensors.
    // ComputeDataType: Data type used for intermediate computations (often float for precision).
    // YDataType: Data type of the output tensor.
    // XElementwiseOperation: The specific elementwise operation to perform (e.g., Add, Mul).
    using XDataType = DataType;
    using ComputeDataType =
        float; // Using float for intermediate calculations can improve numerical stability.
    using YDataType             = DataType;
    using XElementwiseOperation = ck_tile::element_wise::Add;

    // 1. Initialize the input data on the host (CPU).
    // HostTensor is a utility to manage tensor data on the CPU.
    // The first argument is the shape (dimensions) of the tensor {M, N}.
    // The second argument is the strides {stride, 1} for row-major layout.
    // 'x_host_a' and 'x_host_b' are the two input tensors for the elementwise operation.
    ck_tile::HostTensor<XDataType> x_host_a({M, N}, {stride, 1});
    ck_tile::HostTensor<XDataType> x_host_b({M, N}, {stride, 1});
    ck_tile::HostTensor<YDataType> y_host({M, N}, {stride, 1});
    ck_tile::HostTensor<YDataType> y_validation({M, N}, {stride, 1});

    // Fill the host tensors with random data.
    // FillUniformDistribution populates the tensor with values from a uniform distribution,
    // within an interval.
    ck_tile::FillUniformDistribution<XDataType>{0.f, 5.f}(x_host_a);
    ck_tile::FillUniformDistribution<XDataType>{0.f, 5.f}(x_host_b);

    // 2. Create device memory buffers
    // DeviceMem allocates memory on the GPU.
    // The size is determined by the total number of elements and the size of DataType.
    ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes());
    ck_tile::DeviceMem x_buf_b(x_host_b.get_element_space_size_in_bytes());
    ck_tile::DeviceMem y_buf(y_host.get_element_space_size_in_bytes());

    // Copy data from host input tensors to device buffers.
    x_buf_a.ToDevice(x_host_a.data());
    x_buf_b.ToDevice(x_host_b.data());

    // 3. Configure the kernel execution parameters.
    // Dividing the problem into blocktile, warptile, and vector
    // The blocktile is the size of the tile processed by a single work group (also called thread
    // block). The warptile is the size of the tile processed by a single wavefront (also called
    // warp). The vector is the size of the tile processed by a single work item (also called
    // thread). The problem is divided into blocks of size BlockTile. Each block is further divided
    // into wavefronts of size WarpTile. Each wavefront is composed of 64 work items (on AMD; 32
    // threads on NVIDIA). Each work item in a wavefront processes one vector's worth of elements.
    // Note that WarpTile/Vector should be 64 for CDNA (because there are 64 work items per
    // wavefront).
    using BlockTile  = ck_tile::sequence<2048>; // How many elements are handled by a block tile
                                                // (the tensor is divided into blocks of this size)
    using BlockWarps = ck_tile::sequence<8>;    // How many concurrent wavefronts are in a block
                                                // (each wavefront covers part of the block tile)

    // WarpTile: Defines the size of the data sub-tile processed by a single wavefront.
    // With BlockTile=2048, BlockWarps=8 and WarpTile=64, the 8 wavefronts cover 8*64 = 512
    // elements concurrently; since 512 < 2048, each wavefront iterates multiple times over
    // different sets of elements to cover the entire BlockTile.
    using WarpTile = ck_tile::sequence<64>;

    // Vector: Defines the number of elements processed by a single work item in one operation.
    // With sequence<1>, each work item handles one element at a time from its assigned WarpTile
    // portion. If Vector is > 1, it implies vectorized load/store/compute per work item.
    using Vector = ck_tile::sequence<1>;

    // 4. Create the kernel

    // ElementWiseShape bundles these tiling parameters.
    // It calculates derived properties like threads per wavefront, repeats, and total block size.
    using Shape = ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, Vector>;

    // ElementWisePipelineProblem encapsulates all necessary information for the elementwise kernel:
    // - Data types (input, compute, output).
    // - Shape traits (tiling configuration).
    // - The specific elementwise operation (e.g., Add).
    using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
                                                        ComputeDataType,
                                                        YDataType,
                                                        Shape,
                                                        XElementwiseOperation>;

    // ElementWiseKernel refers to the GPU kernel class
    using Kernel = ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;

    // Total number of logical elements the kernel must process (M rows of N elements;
    // padding introduced by stride > N is not processed).
    ck_tile::index_t total_elements = M * N;

    // kBlockSize: The number of work items in a GPU workgroup (thread block).
    // This is BlockWarps * warpSize (e.g. 8 * 64 = 512 on CDNA), consistent with
    // Shape::kBlockSize.
    constexpr ck_tile::index_t kBlockSize =
        ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{});

    // kBlockPerCu: Hint for how many workgroups can be scheduled per Compute Unit (CU).
    // This can influence occupancy and performance.
    constexpr ck_tile::index_t kBlockPerCu = 1;

    // kGridSize: Calculates the total number of workgroups required to process all elements.
    // Each workgroup is responsible for 'elements_per_block' elements.
    // Ceiling division covers the case where 'total_elements' is not perfectly divisible.
    constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
    ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;

    std::cout << "grid size = " << kGridSize << std::endl;
    std::cout << "Total elements = " << total_elements << std::endl;

    auto input_tensors = ck_tile::make_tuple(static_cast<XDataType*>(x_buf_a.GetDeviceBuffer()),
                                             static_cast<XDataType*>(x_buf_b.GetDeviceBuffer()));

    // 5. Run the kernel
    // NOTE: the stride tuples must use the actual row pitch 'stride' (not N); the host
    // tensors and device buffers were laid out with {stride, 1}, so passing {N, 1} here
    // would read/write the wrong elements whenever stride != N.
    float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
                                   ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
                                       Kernel{},
                                       kGridSize,
                                       kBlockSize,
                                       0,
                                       ck_tile::make_tuple(M, N),      // Input size
                                       ck_tile::make_tuple(stride, 1), // Input Stride
                                       ck_tile::make_tuple(stride, 1), // Output Stride
                                       input_tensors,
                                       static_cast<YDataType*>(y_buf.GetDeviceBuffer())));

    std::cout << "Average time: " << ave_time << " ms" << std::endl;

    // 6. Verify the output against a CPU reference implementation.
    bool pass = true;
    if(do_validation)
    {
        y_buf.FromDevice(y_validation.data());
        auto op = [](const auto& v0, const auto& v1) { return v0 + v1; };

        ck_tile::reference_binary_elementwise<XDataType, XDataType, YDataType, ComputeDataType>(
            x_host_a, x_host_b, y_host, op);

        pass = ck_tile::check_err(
            y_validation, y_host, "Elementwise Add Error: Incorrect results!", 0.01, 0.01);
    }

    return pass;
}

// Entry point: parse arguments, dispatch on the requested precision.
// Exit codes: 0 = pass, -1 = bad arguments, -2 = validation failed,
// -3 = unsupported precision.
int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    const std::string data_type = arg_parser.get_str("prec");
    if(data_type == "fp16")
    {
        return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
    }

    // Report the unsupported precision instead of failing silently
    // (matches the behavior of the transpose example).
    std::cerr << "Unsupported data type: " << data_type << std::endl;
    return -3;
}
150 changes: 150 additions & 0 deletions example/ck_tile/19_elementwise/elementwise_example_transpose.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck_tile/host.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/reference/reference_transpose.hpp"

// Build the command-line parser for the transpose example and parse argv.
// Returns (parse_succeeded, parser); the parser carries the option values.
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;

    // Register every option with its default value and help text.
    arg_parser.insert("m", "1024", "m dimension of input");
    arg_parser.insert("n", "1024", "n dimension of input");
    arg_parser.insert("stride_in", "-1", "stride for input M dim, if -1 then equal to n");
    arg_parser.insert("v", "1", "cpu validation or not");
    arg_parser.insert("prec", "fp16", "precision");
    arg_parser.insert("warmup", "10", "cold iter");
    arg_parser.insert("repeat", "50", "hot iter");

    const bool parsed_ok = arg_parser.parse(argc, argv);
    return std::make_tuple(parsed_ok, arg_parser);
}

// Run the transpose-via-elementwise example for one data type.
// The M x N input is copied through a PassThrough elementwise op into an
// N x M output by permuting the output strides. Returns true on success.
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
    ck_tile::index_t M         = arg_parser.get_int("m");
    ck_tile::index_t N         = arg_parser.get_int("n");
    ck_tile::index_t stride_in = arg_parser.get_int("stride_in");

    if(stride_in < 0)
        stride_in = N; // Dense input: stride for M dim is N
    int do_validation = arg_parser.get_int("v");
    int warmup        = arg_parser.get_int("warmup");
    int repeat        = arg_parser.get_int("repeat");

    if(stride_in < N)
    {
        throw std::runtime_error("stride_in must be >= N");
    }

    using XDataType       = DataType;
    using ComputeDataType = float;
    using YDataType       = DataType;
    // Use PassThrough operation for transposition (data is moved, not changed)
    using XElementwiseOperation = ck_tile::element_wise::PassThrough;

    // 1. Initialize the input data on the host (CPU).
    // Input x_host_a: M x N
    // Output y_host: N x M (transposed)
    ck_tile::HostTensor<XDataType> x_host_a({M, N}, {stride_in, 1});
    // Output tensor y_host will have dimensions N x M.
    // Assuming dense output, its stride for the N dimension will be M.
    ck_tile::index_t stride_out_dim0 = M;
    ck_tile::HostTensor<YDataType> y_host({N, M}, {stride_out_dim0, 1});
    ck_tile::HostTensor<YDataType> y_validation({N, M}, {stride_out_dim0, 1});

    // The logical shape for the element-wise operation kernel is based on the input tensor's
    // elements.
    auto op_lengths = ck_tile::make_tuple(M, N); // Lens for the kernel

    ck_tile::FillUniformDistribution<XDataType>{0.f, 5.f}(x_host_a);

    // 2. Create device memory buffers
    ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes());
    ck_tile::DeviceMem y_buf(y_host.get_element_space_size_in_bytes()); // y_host is N x M

    x_buf_a.ToDevice(x_host_a.data());

    // 3. Configure the kernel execution parameters.
    // See elementwise_example.cpp for a detailed discussion of the tiling scheme.
    using BlockTile  = ck_tile::sequence<2048>;
    using BlockWarps = ck_tile::sequence<8>;
    using WarpTile   = ck_tile::sequence<64>;
    using Vector     = ck_tile::sequence<1>;

    using Shape = ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, Vector>;

    // Problem definition for a single input tensor
    using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
                                                        ComputeDataType,
                                                        YDataType,
                                                        Shape,
                                                        XElementwiseOperation>;

    using Kernel = ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;

    ck_tile::index_t total_elements = M * N;

    // Query the warp size instead of hard-coding 64: CDNA wavefronts have 64
    // lanes but RDNA defaults to 32, and the sibling example already uses
    // get_warp_size() for this computation.
    constexpr ck_tile::index_t kBlockSize =
        ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{});
    constexpr ck_tile::index_t kBlockPerCu        = 1;
    constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
    // Ceiling division so a partial final block is still launched.
    ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;

    std::cout << "Input M=" << M << ", N=" << N << ", StrideIn=" << stride_in << std::endl;
    std::cout << "Output N=" << N << ", M=" << M << ", StrideOut=" << stride_out_dim0 << std::endl;
    std::cout << "Grid size = " << kGridSize << ", BlockSize = " << kBlockSize << std::endl;
    std::cout << "Total elements = " << total_elements << std::endl;

    // Input tensors tuple (single input)
    auto input_tensors = ck_tile::make_tuple(static_cast<XDataType*>(x_buf_a.GetDeviceBuffer()));
    // Strides of the input tensor in its (M, N) layout
    auto input_strides = ck_tile::make_tuple(stride_in, 1);
    // Output strides permuted so that writing element (m, n) lands at (n, m)
    // of the dense N x M output — this is what realizes the transpose.
    auto output_strides = ck_tile::make_tuple(1, stride_out_dim0);

    // 4. Run the kernel
    float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
                                   ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
                                       Kernel{},
                                       kGridSize,
                                       kBlockSize,
                                       0,               // Shared memory
                                       op_lengths,      // Logical dimensions for the operation
                                       input_strides,   // Strides for input tensor(s)
                                       output_strides,  // Strides for output tensor (N, M)
                                       input_tensors,
                                       static_cast<YDataType*>(y_buf.GetDeviceBuffer())));

    std::cout << "Average time: " << ave_time << " ms" << std::endl;

    // 5. Verify the output
    bool pass = true;
    if(do_validation)
    {
        y_buf.FromDevice(y_validation.data()); // Copy result from device to y_validation
        ck_tile::reference_transpose_elementwise<XDataType, YDataType>(
            x_host_a, y_host); // Compute reference on host
        pass = ck_tile::check_err(
            y_validation, y_host, "Transpose Error: Incorrect results!", 0.01, 0.01);
    }

    return pass;
}

// Entry point: parse arguments, dispatch on the requested precision.
// Exit codes: 0 = pass, -1 = bad arguments, -2 = validation failed,
// -3 = unsupported precision.
int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
    {
        return -1;
    }

    // Only fp16 is currently wired up; reject anything else up front.
    const std::string data_type = arg_parser.get_str("prec");
    if(data_type != "fp16")
    {
        std::cerr << "Unsupported data type: " << data_type << std::endl;
        return -3;
    }

    return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
}
Loading