
Commit 6c2ca12

vpietila-amd, shumway, Copilot, and JH-Leon-KIM-AMD authored
[CK_BUILDER] First fwd convolution builder implementation (#3070)
* Add experimental builder infrastructure for composable_kernel

  - Add experimental/builder directory with README documentation.
  - Create initial test infrastructure with CMakeLists.txt and placeholder test.
  - Update root CMakeLists.txt to support CK_EXPERIMENTAL_BUILDER option.
  - Update .gitignore to not treat `experimental/builder` as a CMake build directory.

  This establishes the directory structure for a high-level builder pattern that will provide a semantically-clear interface for constructing CK operations, with initial focus on convolution kernels for MIOpen integration.

* Fix clang formatting.
* Fix CMake build infrastructure for experimental builder

  - Add experimental/builder CMakeLists.txt with proper subdirectory structure
  - Add placeholder include/ck_tile/builder CMakeLists.txt for header installation
  - Fix gtest.cmake to use include_guard to prevent multiple inclusions
  - Update root CMakeLists.txt to include full builder directory instead of just tests

* Scope C++20 setting to the test code

  Co-authored-by: Copilot <[email protected]>

* Remove redundant GTest::gtest linkage

  Co-authored-by: Copilot <[email protected]>

* Introduce basic types, and convolution algorithm concepts and limits.
* Add convolution signature concepts.
* Add convolution factory.
* Finalize conv factory implementation for fwd convolutions.
* Add type definitions for testing.
* Add placeholder test.
* Add convolution builder definition.
* Fully functional fwd conv builder.
* Test improvements.
* Clean-up include headers.
* Enable the limit checks for the convolution algorithm parameters.
* Remove dead code.
* clang formatting.
* Add more tests and missing conv specialization argument.
* clang formatting.
* Add explicit handling of the tensor layouts.
* Add complete 2D/3D layout support to CK Builder

  - Add missing 2D layouts: GNHWC_GKYXC_GNHWK, NGCHW_GKCYX_NGKHW
  - Add missing 3D layout: GNDHWC_GKZYXC_GNDHWK
  - Add 1D layouts (NWGC, NGCW, GNWC, NGCW_GKCX) for future support
  - Add 3 tests for new 2D/3D layouts
  - All tests pass (5/5)

* Add tests for remaining 2D/3D layouts

  - Add test for 2D NGCHW_GKYXC_NGKHW (channels-first) with Filter1x1Stride1Pad0
  - Add test for 3D NDHWGC_GKZYXC_NDHWGK (channels-last)
  - All 7 tests pass (complete coverage for all 2D/3D forward layouts)

* Change enum converters to consteval.
* 7 tests with pipeline and specialization

  | Test # | Dim | Type | Layout               | Pipeline | Specialization          |
  |--------|-----|------|----------------------|----------|-------------------------|
  | 1      | 2D  | BF16 | NHWGC_GKYXC_NHWGK    | V1       | DEFAULT                 |
  | 2      | 2D  | FP16 | GNHWC_GKYXC_GNHWK    | V3       | FILTER_1X1_PAD0         |
  | 3      | 2D  | FP32 | NGCHW_GKCYX_NGKHW    | V4       | FILTER_1X1_STRIDE1_PAD0 |
  | 4      | 2D  | BF16 | NHWGC_GKYXC_NHWGK    | V5       | FILTER_3x3              |
  | 5      | 3D  | FP32 | NGCDHW_GKCZYX_NGKDHW | V1       | FILTER_1X1_PAD0         |
  | 6      | 3D  | BF16 | GNDHWC_GKZYXC_GNDHWK | V3       | DEFAULT                 |
  | 7      | 3D  | FP16 | NDHWGC_GKZYXC_NDHWGK | V4       | FILTER_1X1_PAD0         |

* Add missing convolution layouts and provide better compile-time error in instance traits.
* Fix clang formatting.
* Changed I8 -> S8.
* Fix signature.
* Rename concepts and corresponding members.
* Rename LDS related parameters.
* Remove ODD_C specialization. Add V2 pipeline.
* Add missing types.
* Add elementwise operation to the conv signature.
* Improve compile-time error message for unsupported elementwise ops.
* Separate different fwd conv builder tests into separate compilation units.
* Fix layout to string and add name to old CK PassThrough elementwise op.
* Enable both CK and CK Tile tensor layouts in instance traits.
* Fix clang-format.

---------

Co-authored-by: John Shumway <[email protected]>
Co-authored-by: John Shumway <[email protected]>
Co-authored-by: Copilot <[email protected]>
Co-authored-by: JH-Leon-KIM-AMD <[email protected]>
1 parent 5c19740 commit 6c2ca12

21 files changed: +1527 -3 lines changed

experimental/builder/README.md

Lines changed: 11 additions & 2 deletions
@@ -23,9 +23,18 @@ This project is a prototype for a more general builder pattern for all of compos
 
 To enable the experimental builder, configure your build with:
 
-```sh
-cmake -DCK_EXPERIMENTAL_BUILDER=ON -DCMAKE_CXX_STANDARD=20 ...
+```bash
+cmake \
+  -D CMAKE_PREFIX_PATH=/opt/rocm \
+  -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
+  -D CMAKE_BUILD_TYPE=Release \
+  -D GPU_TARGETS="gfx942;gfx950" \
+  -D CK_EXPERIMENTAL_BUILDER=ON \
+  -D CMAKE_CXX_STANDARD=20 \
+  -G Ninja \
+  ..
 ```
+
 ## Building and testing
 
 During development, build and test from the CK build directory with
Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck/utility/sequence.hpp"
+#include "ck_tile/builder/types.hpp"
+
+namespace ck_tile::builder {
+
+// Convert a static array to a sequence
+// Usage example:
+//   static constexpr std::array arr{1, 2, 3};
+//   using seq = to_sequence_v<arr>; // seq is ck::Sequence<1, 2, 3>
+template <typename T, const T& Arr>
+struct to_sequence_t
+{
+    private:
+    template <std::size_t... Is>
+    static auto get_sequence_type(std::index_sequence<Is...>) -> ck::Sequence<Arr[Is]...>;
+
+    // Helper method to handle the unusual .Size() method name in ck::Array.
+    static constexpr auto get_size(const auto& arr)
+    {
+        if constexpr(requires { arr.size(); })
+        {
+            return arr.size();
+        }
+        else
+        {
+            return arr.Size();
+        }
+    }
+
+    public:
+    using value = decltype(get_sequence_type(std::make_index_sequence<get_size(Arr)>{}));
+};
+
+template <auto& Arr>
+using to_sequence_v = typename to_sequence_t<std::remove_cvref_t<decltype(Arr)>, Arr>::value;
+
+// Wrapper struct to make constexpr strings a structural type for NTTP.
+template <size_t N>
+struct StringLiteral
+{
+    char data[N];
+    constexpr StringLiteral(const char (&str)[N])
+    {
+        for(size_t i = 0; i < N; ++i)
+            data[i] = str[i];
+    }
+
+    constexpr bool operator==(const StringLiteral<N>& other) const
+    {
+        for(size_t i = 0; i < N; ++i)
+        {
+            if(data[i] != other.data[i])
+            {
+                return false;
+            }
+        }
+        return true;
+    }
+};
+
+// This is a C++17 deduction guide. It allows the compiler to automatically
+// deduce the template argument `N` for `StringLiteral` from a string literal
+// constructor argument. For example, you can write `StringLiteral s{"foo"};`
+// instead of `StringLiteral<4> s{"foo"};`.
+template <size_t N>
+StringLiteral(const char (&)[N]) -> StringLiteral<N>;
+
+// Helper to provide a readable error for unsupported enum values.
+// The compiler will print the name of this struct in the error message, so
+// the name of the enum value will appear instead of just its integer value.
+template <auto T>
+struct UnsupportedEnumValue
+{
+};
+
+// Helper functions to convert enums to strings
+constexpr std::string_view ConvDirectionToString(ConvDirection dir)
+{
+    switch(dir)
+    {
+    case ConvDirection::FORWARD: return "Forward";
+    case ConvDirection::BACKWARD_DATA: return "Backward Data";
+    case ConvDirection::BACKWARD_WEIGHT: return "Backward Weight";
+    default: return "Unknown";
+    }
+}
+
+constexpr std::string_view DataTypeToString(DataType dt)
+{
+    switch(dt)
+    {
+    case DataType::FP16: return "FP16";
+    case DataType::FP32: return "FP32";
+    case DataType::BF16: return "BF16";
+    case DataType::FP8: return "FP8";
+    case DataType::I8: return "I8";
+    case DataType::U8: return "U8";
+    default: return "Unknown";
+    }
+}
+
+constexpr std::string_view LayoutToString(GroupConvLayout1D layout)
+{
+    switch(layout)
+    {
+    case GroupConvLayout1D::GNWC_GKXC_GNWK: return "GNWC_GKXC_GNWK";
+    case GroupConvLayout1D::NWGC_GKXC_NWGK: return "NWGC_GKXC_NWGK";
+    case GroupConvLayout1D::NGCW_GKXC_NGKW: return "NGCW_GKXC_NGKW";
+    case GroupConvLayout1D::NGCW_GKCX_NGKW: return "NGCW_GKCX_NGKW";
+    default: return "Unknown";
+    }
+}
+
+constexpr std::string_view LayoutToString(GroupConvLayout2D layout)
+{
+    switch(layout)
+    {
+    case GroupConvLayout2D::GNHWC_GKYXC_GNHWK: return "GNHWC_GKYXC_GNHWK";
+    case GroupConvLayout2D::NHWGC_GKYXC_NHWGK: return "NHWGC_GKYXC_NHWGK";
+    case GroupConvLayout2D::NGCHW_GKYXC_NGKHW: return "NGCHW_GKYXC_NGKHW";
+    case GroupConvLayout2D::NGCHW_GKCYX_NGKHW: return "NGCHW_GKCYX_NGKHW";
+    default: return "Unknown";
+    }
+}
+
+constexpr std::string_view LayoutToString(GroupConvLayout3D layout)
+{
+    switch(layout)
+    {
+    case GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK: return "GNDHWC_GKZYXC_GNDHWK";
+    case GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK: return "NDHWGC_GKZYXC_NDHWGK";
+    case GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW: return "NGCDHW_GKZYXC_NGKDHW";
+    case GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW: return "NGCDHW_GKCZYX_NGKDHW";
+    default: return "Unknown";
+    }
+}
+
+} // namespace ck_tile::builder
Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <type_traits>
+#include <concepts>
+#include <array>
+
+#include "ck_tile/builder/types.hpp"
+
+namespace ck_tile::builder {
+
+/********************************************************************/
+/* Descriptors for individual elements of the algorithm description */
+/********************************************************************/
+
+// Concept for thread block dimensions for a GEMM problem.
+template <typename T>
+concept ThreadBlockDescriptor = requires(T t) {
+    { t.block_size } -> std::convertible_to<size_t>;
+    { t.tile_size.m } -> std::convertible_to<size_t>;
+    { t.tile_size.n } -> std::convertible_to<size_t>;
+    { t.tile_size.k } -> std::convertible_to<size_t>;
+};
+
+// Concept for parameters that describe a gridwise GEMM problem.
+template <typename T>
+concept GridwiseGemmDescriptor = requires(T t) {
+    { t.ak1 } -> std::convertible_to<size_t>;
+    { t.bk1 } -> std::convertible_to<size_t>;
+    { t.m_per_xdl } -> std::convertible_to<size_t>;
+    { t.n_per_xdl } -> std::convertible_to<size_t>;
+    { t.m_xdl_per_wave } -> std::convertible_to<size_t>;
+    { t.n_xdl_per_wave } -> std::convertible_to<size_t>;
+};
+
+// Concept for vectorized data transfer for convolution input tensors.
+template <typename T>
+concept BlockTransferDescriptor = requires(T t) {
+    { t.k0 } -> std::convertible_to<size_t>;
+    { t.m_n } -> std::convertible_to<size_t>;
+    { t.k1 } -> std::convertible_to<size_t>;
+};
+
+// Concept for thread cluster dimensions for GEMM output tensor.
+template <typename T>
+concept ThreadClusterDescriptor = requires(T t) {
+    { t.m_block } -> std::convertible_to<size_t>;
+    { t.m_wave_per_xdl } -> std::convertible_to<size_t>;
+    { t.n_block } -> std::convertible_to<size_t>;
+    { t.n_wave_per_xdl } -> std::convertible_to<size_t>;
+};
+
+// Concept for the LDS transfer for the convolution input tensors.
+template <typename T>
+concept LdsTransferDescriptor = requires(T t) {
+    { t.src_vector_dim } -> std::convertible_to<size_t>;
+    { t.src_scalar_per_vector } -> std::convertible_to<size_t>;
+    { t.lds_dst_scalar_per_vector } -> std::convertible_to<size_t>;
+    { t.is_direct_load } -> std::convertible_to<bool>;
+    { t.lds_padding } -> std::convertible_to<bool>;
+};
+
+// Concept for the convolution output tensor epilogue (copy from registers to global memory via
+// LDS).
+template <typename T>
+concept EpilogueDescriptor = requires(T t) {
+    { t.m_xdl_per_wave_per_shuffle } -> std::convertible_to<size_t>;
+    { t.n_xdl_per_wave_per_shuffle } -> std::convertible_to<size_t>;
+    { t.scalar_per_vector } -> std::convertible_to<size_t>;
+};
+
+// Concept for the thread cluster access order
+template <typename T>
+concept AccessOrderDescriptor = requires(T t) {
+    { t.order } -> std::convertible_to<std::array<size_t, 3>>;
+};
+
+// No requirements yet for a ConvAlgorithm concept.
+template <typename T>
+concept ConvAlgorithmDescriptor = std::is_class_v<T>;
+
+/******************************************** */
+/* Requirements for the algorithm description */
+/******************************************** */
+
+// Concept to check if struct specifies thread block info.
+template <typename T>
+concept SpecifiesThreadBlock = requires {
+    { T::thread_block } -> ThreadBlockDescriptor;
+};
+
+// Concept to check if a struct specifies gridwise GEMM info.
+template <typename T>
+concept SpecifiesGridwiseGemm = requires {
+    { T::gridwise_gemm } -> GridwiseGemmDescriptor;
+};
+
+// Concept to check if a struct specifies convolution input and output block transfer info.
+template <typename T>
+concept SpecifiesBlockTransfer = requires(T t) {
+    { T::block_transfer.block_transfer_a } -> BlockTransferDescriptor;
+    { T::block_transfer.block_transfer_b } -> BlockTransferDescriptor;
+    { T::block_transfer.thread_cluster_dims_c } -> ThreadClusterDescriptor;
+};
+
+// Concept to check if a struct specifies LDS transfer info for tensors A, B, and C.
+template <typename T>
+concept SpecifiesLdsTransfer = requires(T t) {
+    { T::block_transfer.lds_transfer_a } -> LdsTransferDescriptor;
+    { T::block_transfer.lds_transfer_b } -> LdsTransferDescriptor;
+    { T::block_transfer.epilogue_c } -> EpilogueDescriptor;
+};
+
+// Concept to check if a struct specifies thread cluster access order info.
+template <typename T>
+concept SpecifiesThreadClusterAccessOrder = requires(T t) {
+    { T::block_transfer.block_transfer_access_order_a } -> AccessOrderDescriptor;
+    { T::block_transfer.block_transfer_access_order_b } -> AccessOrderDescriptor;
+};
+
+// Concept to check if a struct specifies source access order info.
+template <typename T>
+concept SpecifiesSourceAccessOrder = requires(T t) {
+    { T::block_transfer.src_access_order_a } -> AccessOrderDescriptor;
+    { T::block_transfer.src_access_order_b } -> AccessOrderDescriptor;
+};
+
+// Concept to check if struct specifies block_gemm_pipeline_version.
+template <typename T>
+concept SpecifiesGemmPipelineVersion = requires {
+    { T::pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
+};
+
+template <typename T>
+concept SpecifiesFwdConcSpecialization = requires {
+    { T::fwd_specialization } -> std::convertible_to<ConvFwdSpecialization>;
+};
+
+} // namespace ck_tile::builder
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <type_traits>
+#include <concepts>
+
+namespace ck_tile::builder {
+
+// Limits for input vector transfer.
+template <auto Value>
+concept InputVectorTransferLimits = requires {
+    requires Value.src_vector_dim > 0 && Value.src_scalar_per_vector > 0 &&
+                 Value.lds_dst_scalar_per_vector > 0;
+};
+
+// Limits for output vector transfer.
+template <auto Value>
+concept OutputVectorTransferLimits = requires {
+    requires Value.scalar_per_vector > 0 && Value.m_xdl_per_wave_per_shuffle > 0 &&
+                 Value.n_xdl_per_wave_per_shuffle > 0;
+};
+
+// Limits for access order. Must be a permutation of {0, 1, 2}.
+template <auto Value>
+concept AccessOrderLimits = requires {
+    requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[1] != Value[2]) &&
+             (Value[0] >= 0 && Value[0] < 3) && (Value[1] >= 0 && Value[1] < 3) &&
+             (Value[2] >= 0 && Value[2] < 3));
+};
+
+} // namespace ck_tile::builder
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <concepts>
+#include <type_traits>
+
+#include "ck_tile/builder/conv_factory.hpp"
+#include "ck_tile/builder/versions.hpp"
+
+namespace ck_tile::builder {
+
+/**
+ * @brief Top-level builder for creating convolution kernel instances.
+ *
+ * This struct serves as the main entry point for generating a convolution kernel.
+ * It uses a factory pattern based on the provided signature, algorithm, and version
+ * to construct the appropriate kernel instance.
+ *
+ * @tparam SIGNATURE The convolution signature, which describes the mathematical functionality of
+ * the algorithm (e.g., data types, layouts, direction).
+ * @tparam ALGORITHM The specific convolution algorithm to be used for the implementation.
+ * @tparam VERSION The version of the builder implementation.
+ */
+template <ConvSignatureDescriptor auto SIGNATURE,
+          ConvAlgorithmDescriptor auto ALGORITHM,
+          StringLiteral VERSION = LATEST_API_VERSION>
+    requires SupportedVersion<VERSION> && ValidConvSignature<SIGNATURE>
+struct ConvBuilder
+{
+    static constexpr auto kVersion = VERSION;
+    using Factory                  = ConvFactory<SIGNATURE, ALGORITHM, VERSION>;
+    // Output: The kernel class.
+    using Instance = Factory::Instance;
+};
+
+} // namespace ck_tile::builder
