Skip to content

Commit 9249846

Browse files
authored
[CK_Builder] removed direction and elementwise_operation from required parameters … (#3192)
Removed direction and elementwise operation from default values required for convolution signature concept. Added constexpr helpers to set default values. Add compile-time tests.
1 parent 22a934a commit 9249846

File tree

3 files changed

+54
-4
lines changed

3 files changed

+54
-4
lines changed

experimental/builder/include/ck_tile/builder/conv_factory.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
563563
SPATIAL_DIM,
564564
ConvDirection::FORWARD>());
565565
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
566-
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
566+
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
567567
using AlgorithmType = decltype(ALGORITHM);
568568

569569
static_assert(ALGORITHM.block_transfer.lds_transfer_a.is_direct_load ==

experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,41 @@ concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == D
4343
template <typename T>
4444
concept ConvLayout = std::same_as<std::remove_cvref_t<T>, GroupConvLayout>;
4545

46+
template <typename T>
47+
concept HasElementwiseOp = requires(T t) {
48+
{ t.elementwise_operation };
49+
};
50+
51+
template <typename T>
52+
concept HasConvolutionDirection = requires(T t) {
53+
{ t.direction };
54+
};
55+
56+
// Note: it is not required to provide an ElementwiseOp, but if one is provided, check if well
57+
// defined
58+
template <typename T>
59+
concept ElementwiseOpWellDefinedIfProvided = requires(T t) {
60+
requires !HasElementwiseOp<T> || requires {
61+
{ t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
62+
};
63+
};
64+
65+
// Note: it is not required to provide a convolution, but if one is provided, check if well defined
66+
template <typename T>
67+
concept ConvolutionDirectionWellDefinedIfProvided = requires(T t) {
68+
requires !HasConvolutionDirection<T> || requires {
69+
{ t.direction } -> std::convertible_to<ConvDirection>;
70+
};
71+
};
72+
4673
// Concept for a type that defines a convolution's operational signature.
4774
template <typename T>
4875
concept ConvSignatureDescriptor = requires(T t) {
4976
{ t.spatial_dim } -> std::convertible_to<unsigned int>;
50-
{ t.direction } -> std::convertible_to<ConvDirection>;
5177
{ t.layout } -> ConvLayout;
5278
{ t.data_type } -> std::convertible_to<DataType>;
53-
{ t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
79+
requires ElementwiseOpWellDefinedIfProvided<T>;
80+
requires ConvolutionDirectionWellDefinedIfProvided<T>;
5481
};
5582

5683
// Concept to validate a convolution signature's values.

experimental/builder/test/test_conv_description.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,37 @@ namespace ckt = ck_tile::test;
1919
// Defines the signature of the convolution operation to be tested.
2020
// This includes dimensionality, direction, data layout, and data type.
2121
struct ConvSignature
22+
{
23+
int spatial_dim = 2;
24+
ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
25+
ckb::DataType data_type = ckb::DataType::FP16;
26+
ckb::GroupConvDeviceOp device_operation =
27+
ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
28+
};
29+
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
30+
31+
// Compile time tests for concepts
32+
struct ConvSignatureWithOptionalParams
2233
{
2334
int spatial_dim = 2;
2435
ckb::ConvDirection direction = ckb::ConvDirection::FORWARD;
2536
ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
2637
ckb::DataType data_type = ckb::DataType::FP16;
2738
ckb::ElementwiseOperation elementwise_operation = ckb::ElementwiseOperation::PASS_THROUGH;
2839
};
29-
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
40+
static_assert(ckb::ConvSignatureDescriptor<ConvSignatureWithOptionalParams>);
41+
42+
struct ConvSignatureWithInvalidOptionalParams
43+
{
44+
int spatial_dim = 2;
45+
ckb::ConvDirection direction = ckb::ConvDirection::FORWARD;
46+
ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
47+
ckb::DataType data_type = ckb::DataType::FP16;
48+
int elementwise_operation = 7; // this should fail
49+
ckb::GroupConvDeviceOp device_operation =
50+
ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
51+
};
52+
static_assert(!ckb::ConvSignatureDescriptor<ConvSignatureWithInvalidOptionalParams>);
3053

3154
struct DefaultAlgorithm
3255
{

0 commit comments

Comments
 (0)