Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance Optimization: Optimized TileShape Configuration for f8 #3617

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -145,17 +145,17 @@ __global__ void set_kernel_args_kernel(
GroupedGemmArgs::ProblemShape::UnderlyingProblemShape*>(
problem_shape_buf);
// Pass dummy configs to get Stride structure
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
StrideInputA* stride_input_A_ptr = reinterpret_cast<
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
StrideInputA*>(stride_buf);
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
StrideInputB* stride_input_B_ptr = reinterpret_cast<
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
StrideInputB*>(stride_buf + stride_size);
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
StrideOutput* stride_output_ptr = reinterpret_cast<
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
StrideOutput*>(stride_buf + (stride_size * 2));

output_args_ptr[group_index] =
Expand All @@ -169,15 +169,15 @@ __global__ void set_kernel_args_kernel(
GroupedGemmArgs::ProblemShape::UnderlyingProblemShape(M, N, K);
stride_input_A_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmArgs::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputA{},
GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::StrideInputA{},
{M, K, 1});
stride_input_B_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmArgs::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputB{},
GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::StrideInputB{},
{N, K, 1});
stride_output_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmArgs::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideOutput{},
GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::StrideOutput{},
{M, N, 1});
}
}
Expand Down Expand Up @@ -219,17 +219,17 @@ __global__ void set_dynamic_kernel_args_kernel(
GroupedGemmArgs::ProblemShape::UnderlyingProblemShape*>(
problem_shape_buf);
// Pass dummy configs to get Stride structure
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
StrideInputA* stride_input_A_ptr = reinterpret_cast<
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
StrideInputA*>(stride_buf);
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
StrideInputB* stride_input_B_ptr = reinterpret_cast<
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
StrideInputB*>(stride_buf + stride_size);
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
StrideOutput* stride_output_ptr = reinterpret_cast<
GroupedGemmArgs::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmArgs::GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::
StrideOutput*>(stride_buf + (stride_size * 2));

output_args_ptr[group_index] =
Expand All @@ -244,15 +244,15 @@ __global__ void set_dynamic_kernel_args_kernel(
zero_start_index_M[group_index], N, K);
stride_input_A_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmArgs::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputA{},
GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::StrideInputA{},
{zero_start_index_M[group_index], K, 1});
stride_input_B_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmArgs::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputB{},
GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::StrideInputB{},
{N, K, 1});
stride_output_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmArgs::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideOutput{},
GroupedGemmConfigs<128, 256, 128, 2, 1, 1, false>::StrideOutput{},
{zero_start_index_M[group_index], N, 1});
}
}
Expand Down Expand Up @@ -487,7 +487,7 @@ std::vector<at::Tensor> dispatch_fp8_grouped_kernel(
return f8f8bf16_grouped_impl<64, 128, 128, 2, 1, 1, true, FastAccum>(
xq_group, wq_group, scale, zero_start_index_M);
} else if (kernel == KernelMode::Large) {
return f8f8bf16_grouped_impl<128, 128, 128, 2, 1, 1, true, FastAccum>(
return f8f8bf16_grouped_impl<128, 256, 128, 2, 1, 1, false, FastAccum>(
xq_group, wq_group, scale, zero_start_index_M);
} else {
return f8f8bf16_grouped_impl<128, 128, 128, 1, 2, 1, true, FastAccum>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,23 @@ at::Tensor f8f8bf16_rowwise_impl(
using EpilogueEVT =
cute::conditional_t<USE_BIAS, EVTComputeBias, EVTCompute1>;

using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using FastDefaultSchedule =
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
using FastPongSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using SlowAccum = cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;
using FastAccum =
cute::conditional_t<PONG, FastPongSchedule, FastDefaultSchedule>;
using CooperativeEpilogueSchedule =
cutlass::epilogue::TmaWarpSpecializedCooperative;
using PongEpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using MainLoopSchedule =
cute::conditional_t<FAST_ACCUM, FastAccum, SlowAccum>;
using EpilogueSchedule = cute::
conditional_t<PONG, PongEpilogueSchedule, CooperativeEpilogueSchedule>;

using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90,
Expand All @@ -185,21 +202,9 @@ at::Tensor f8f8bf16_rowwise_impl(
ElementOutput,
LayoutOutput,
AlignmentOutput,
cutlass::epilogue::TmaWarpSpecialized,
EpilogueSchedule,
EpilogueEVT>::CollectiveOp;

using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using FastDefaultSchedule =
cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
using FastPongSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using SlowAccum = cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;
using FastAccum =
cute::conditional_t<PONG, FastPongSchedule, FastDefaultSchedule>;
using MainLoopSchedule =
cute::conditional_t<FAST_ACCUM, FastAccum, SlowAccum>;

using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
Expand Down Expand Up @@ -331,20 +336,20 @@ at::Tensor dispatch_fp8_rowwise_kernel(
2,
1,
1,
false,
true,
FastAccum,
UseBias,
InputDType,
BiasDType>(XQ, WQ, x_scale, w_scale, bias, output);
} else if (kernel == KernelMode::Large) {
return f8f8bf16_rowwise_impl<
128,
128,
256,
128,
2,
1,
1,
true,
false,
FastAccum,
UseBias,
InputDType,
Expand All @@ -354,10 +359,10 @@ at::Tensor dispatch_fp8_rowwise_kernel(
128,
128,
128,
1,
2,
1,
false,
1,
true,
FastAccum,
UseBias,
InputDType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,22 @@ at::Tensor f8f8bf16_tensorwise_impl(
KernelScheduleAuto; // Kernel to launch based on the default setting in
// the Collective Builder

using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using FastDefaultSchedule =
cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
using FastPongSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using SlowAccum = cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;
using FastAccum =
cute::conditional_t<PONG, FastPongSchedule, FastDefaultSchedule>;
using CooperativeEpilogueSchedule =
cutlass::epilogue::TmaWarpSpecializedCooperative;
using PongEpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using MainLoopSchedule =
cute::conditional_t<FAST_ACCUM, FastAccum, SlowAccum>;
using EpilogueSchedule = cute::
conditional_t<PONG, PongEpilogueSchedule, CooperativeEpilogueSchedule>;

using Scale_ =
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementComputeEpilogue>;
Expand Down Expand Up @@ -140,7 +145,7 @@ at::Tensor f8f8bf16_tensorwise_impl(
ElementOutput,
LayoutOutput,
AlignmentOutput,
cutlass::epilogue::TmaWarpSpecialized,
EpilogueSchedule,
EpilogueEVT>::CollectiveOp;

using CollectiveMainloop =
Expand Down Expand Up @@ -239,10 +244,10 @@ at::Tensor f8f8bf16_tensorwise(
return f8f8bf16_tensorwise_impl<64, 128, 128, 2, 1, 1, true, true>(
XQ, WQ, scale);
} else if (kernel == KernelMode::Large) {
return f8f8bf16_tensorwise_impl<128, 128, 128, 2, 1, 1, true, true>(
return f8f8bf16_tensorwise_impl<128, 256, 128, 2, 1, 1, false, true>(
XQ, WQ, scale);
} else {
return f8f8bf16_tensorwise_impl<128, 128, 128, 1, 2, 1, false, true>(
return f8f8bf16_tensorwise_impl<128, 128, 128, 1, 2, 1, true, true>(
XQ, WQ, scale);
}
}
Expand Down
Loading