Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "grouped_convolution_forward_invoker.hpp"
#include "run_grouped_convolution_fwd_example.inc"

template <template <typename PrecType> typename GemmConfig>
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_fwd_example(int argc, char* argv[])
{
using Invoker = GroupedConvolutionForwardInvoker;
Expand All @@ -31,14 +31,14 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])
if(data_type == "fp16")
{
return run_grouped_conv_fwd_example_prec_type<Invoker,
GemmConfig<ck_tile::half_t>,
ConvConfig<ck_tile::half_t>,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_grouped_conv_fwd_example_prec_type<Invoker,
GemmConfig<ck_tile::bf16_t>,
ConvConfig<ck_tile::bf16_t>,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
Expand All @@ -50,9 +50,17 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])

int main(int argc, char* argv[])
{
try
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_fwd_example<ConvConfigComputeV3_WMMA>(argc, argv);
return !run_grouped_conv_fwd_example<ConvConfigComputeV3_WMMA>(argc, argv);
#else
return !run_grouped_conv_fwd_example<ConvConfigComputeV3>(argc, argv);
return !run_grouped_conv_fwd_example<ConvConfigComputeV3>(argc, argv);
#endif
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
struct GroupedConvolutionForwardInvoker
{
template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename ConvConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
Expand All @@ -25,23 +25,22 @@ struct GroupedConvolutionForwardInvoker

// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
ck_tile::sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
sequence<ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>,
ConvConfig::PermuteA,
ConvConfig::PermuteB>;

constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr ck_tile::index_t NumGroupsToMerge = 1;
constexpr ck_tile::index_t VectorSizeA = ConvConfig::VectorSizeA;
constexpr ck_tile::index_t VectorSizeB = ConvConfig::VectorSizeB;
constexpr ck_tile::index_t VectorSizeC = ConvConfig::VectorSizeC;

constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
ConvConfig::TileParitionerGroupNum,
ConvConfig::TileParitionerM01>;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
Expand All @@ -51,22 +50,22 @@ struct GroupedConvolutionForwardInvoker
VectorSizeA,
VectorSizeB,
VectorSizeC,
NumGroupsToMerge,
ConvConfig::NumGroupsToMerge,
CDElementWise>;

using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ConvConfig::kPadM,
ConvConfig::kPadN,
ConvConfig::kPadK,
ConvConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
ConvConfig::TransposeC,
ConvConfig::UseStructuredSparsity,
false, // Persistent,
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
ConvConfig::NumWaveGroups,
ConvConfig::Preshuffle>;

using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
InDataType,
Expand All @@ -82,16 +81,16 @@ struct GroupedConvolutionForwardInvoker
VectorSizeB>;

using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;

const ck_tile::index_t gemm_k =
args.C_ * std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
std::multiplies<ck_tile::index_t>());

const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
Expand All @@ -101,7 +100,7 @@ struct GroupedConvolutionForwardInvoker
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto scheduler = ConvConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;

using UniversalGemmProblem =
Expand All @@ -121,7 +120,7 @@ struct GroupedConvolutionForwardInvoker
VectorSizeB>;

using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;

using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
InDataType,
Expand All @@ -134,12 +133,12 @@ struct GroupedConvolutionForwardInvoker
CDElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
GemmConfig::TransposeC,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
ConvConfig::TransposeC,
memory_operation,
1,
true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#pragma once

template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename ConvConfig,
typename Invoker,
typename InDataType,
typename WeiDataType,
Expand All @@ -17,7 +17,7 @@ float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs<>& args,
int n_repeat)
{
float ave_time = Invoker::template grouped_conv_fwd<NDimSpatial,
GemmConfig,
ConvConfig,
InDataType,
WeiDataType,
AccDataType,
Expand All @@ -39,7 +39,7 @@ float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs<>& args,
}

template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename ConvConfig,
typename Invoker,
typename InDataType,
typename WeiDataType = InDataType,
Expand Down Expand Up @@ -141,7 +141,7 @@ int run_grouped_conv_fwd_example_with_layouts(
std::cout << "output: " << output.mDesc << std::endl;

invoke_grouped_conv_fwd<NDimSpatial,
GemmConfig,
ConvConfig,
Invoker,
InDataType,
WeiDataType,
Expand Down Expand Up @@ -193,7 +193,7 @@ int run_grouped_conv_fwd_example_with_layouts(
}

template <typename Invoker,
typename GemmConfig,
typename ConvConfig,
typename InPrecType,
typename WeiPrecType = InPrecType,
typename OutPrecType = InPrecType>
Expand All @@ -215,7 +215,7 @@ int run_grouped_conv_fwd_example_prec_type(
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<1>{},
GemmConfig,
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
Expand All @@ -225,7 +225,7 @@ int run_grouped_conv_fwd_example_prec_type(
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<2>{},
GemmConfig,
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
Expand All @@ -235,7 +235,7 @@ int run_grouped_conv_fwd_example_prec_type(
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
{
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<3>{},
GemmConfig,
ConvConfig,
Invoker,
InPrecType,
WeiPrecType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -631,8 +631,6 @@ struct GroupedConvolutionBackwardWeightKernel
CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
return false;
}

// TODO: Should we also check that GemmM <= MPerBlock and GemmN <= NPerBlock?
}

return true;
Expand Down
Loading