kernel: support fp8 kv cache #381

Open · wants to merge 7 commits into main
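This PR threads a second element type, KV_DType, through the SM80 attention path (traits instantiation, kernel, and launcher) so the KV cache can be stored in a narrower format such as fp8 while Q/O keep their original dtype. Below is a hedged usage sketch of the updated launcher signature only: the wrapper name launch_fp8_kv_example, the fixed head dim of 128, and the choice of cute::float_e4m3_t as the fp8 element type are illustrative assumptions, not taken from this diff.

#include <cuda_runtime.h>
#include <cute/numeric/numeric_types.hpp>  // cute::half_t, cute::float_e4m3_t

#include "attention_launch_sm80.cuh"  // run_attention_kernel_sm80

// Illustrative only: Params is whatever the call sites below construct.
template <typename Params>
void launch_fp8_kv_example(Params& params, cudaStream_t stream) {
  // existing behavior: KV cache stored in the same dtype as Q/O
  run_attention_kernel_sm80<cute::half_t, cute::half_t, /*HEAD_DIM=*/128>(params, stream);

  // what this PR enables: a lower-precision KV cache (fp8 e4m3 assumed here)
  // while Q/O and the MMA compute type stay in half precision
  run_attention_kernel_sm80<cute::half_t, cute::float_e4m3_t, /*HEAD_DIM=*/128>(params, stream);
}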
4 changes: 2 additions & 2 deletions src/kernels/attention/attention_bench_sm80.cu
@@ -74,8 +74,8 @@ void attention_bench_sm80(nvbench::state& state) {

state.exec([&](nvbench::launch& launch) {
DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
run_attention_kernel_sm80<cute::half_t, HEAD_DIM>(params,
launch.get_stream());
run_attention_kernel_sm80<cute::half_t, cute::half_t, HEAD_DIM>(
params, launch.get_stream());
});
});
}
82 changes: 47 additions & 35 deletions src/kernels/attention/attention_kernel_sm80.cuh
@@ -10,6 +10,7 @@
#include "cute/config.hpp"
#include "cute_extensions.cuh"
#include "fast_cast.cuh"
#include "layout_conformance.cuh"
#include "mask.h"
#include "online_softmax.cuh"
#include "ptx.cuh"
@@ -38,6 +39,7 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {

// type alias
using DType = typename Traits::DType;
using KV_DType = typename Traits::KV_DType;

using TiledMma = typename Traits::TiledMma;
using Layout = typename Traits::LayoutConvertor;
@@ -69,10 +71,10 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
auto [Q, O] = tile.template get_qo_tile<DType>(batch_idx, head_idx);
// (kv_len, HEAD_DIM)
auto [K, V] =
tile.template get_kv_tile<DType>(batch_idx, head_idx / group_size);
tile.template get_kv_tile<KV_DType>(batch_idx, head_idx / group_size);

const int q_len = size<0>(Q.shape());
const int kv_len = size<0>(K.shape());
const int q_len = size<0>(Q);
const int kv_len = size<0>(K);

if (m_block * kBlockM >= q_len) {
// out of bound, return
@@ -110,8 +112,8 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
// Smem
extern __shared__ char smem[];
DType* q_smem = (DType*)smem;
DType* k_smem = q_smem + cosize(SmemLayoutQ{});
DType* v_smem = k_smem + cosize(SmemLayoutK{});
KV_DType* k_smem = q_smem + cosize(SmemLayoutQ{});
KV_DType* v_smem = k_smem + cosize(SmemLayoutK{});

// (BLK_M, BLK_K), k-major
Tensor sQ = make_tensor(make_smem_ptr(q_smem), SmemLayoutQ{});
@@ -141,10 +143,10 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
auto produce_q = [&]() {
auto tQgQ = gmem_thr_copy_Q.partition_S(gQ);
auto tQsQ = gmem_thr_copy_Q.partition_D(sQ);
safe_copy<EVEN_K,
/*EVEN_MN=*/false,
/*ZERO_FILL_MN=*/true,
/*ZERO_FILL_K=*/true>(
safe_copy</*EVEN_MN=*/false,
EVEN_K,
/*ZFILL_MN=*/true,
/*ZFILL_K=*/true>(
gmem_tiled_copy_Q,
tQgQ,
tQsQ,
@@ -157,37 +159,36 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
auto produce_k = [&](int ni) {
auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni));
// skip zero fill oob for k since mask will mask out oob with -inf
safe_copy<EVEN_K,
/*EVEN_MN=*/false,
/*ZERO_FILL_MN=*/false,
/*ZERO_FILL_K=*/true>(
gmem_tiled_copy_Q,
tKgK,
tKsK,
tKcKV,
make_coord(kv_len - ni * kBlockN, head_dim));
safe_copy</*EVEN_MN=*/false,
EVEN_K,
/*ZFILL_MN=*/false,
/*ZFILL_K=*/true>(gmem_tiled_copy_Q,
tKgK,
tKsK,
tKcKV,
make_coord(kv_len - ni * kBlockN, head_dim));
};

Tensor tVsV = gmem_thr_copy_KV.partition_D(sV);
auto produce_v = [&](int ni) {
auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni));
// TODO: skip zero fill oob for v, may have nan issue
safe_copy<EVEN_K,
/*EVEN_MN=*/false,
/*ZERO_FILL_MN=*/true,
/*ZERO_FILL_K=*/true>(
gmem_tiled_copy_Q,
tVgV,
tVsV,
tKcKV,
make_coord(kv_len - ni * kBlockN, head_dim));
safe_copy</*EVEN_MN=*/false,
EVEN_K,
/*ZFILL_MN=*/true,
/*ZFILL_K=*/true>(gmem_tiled_copy_Q,
tVgV,
tVsV,
tKcKV,
make_coord(kv_len - ni * kBlockN, head_dim));
};

TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_slice(tidx);
// GEMM-I: S = Q@K.T
auto tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
auto tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
// allocate register with type KV_DType for K
auto tSrK = make_fragment_B<KV_DType>(thr_mma, sK); // (MMA,MMA_N,MMA_K)

// s2r tiled copy for qkv
SmemTiledCopyQ smem_tiled_copy_Q;
@@ -219,12 +220,18 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
tSsK(_, _, next_ki),
tSrK_copy_view(_, _, next_ki));
}
cute::gemm(tiled_mma, tSrQ(_, _, ki), tSrK(_, _, ki), tSrAccS);

if constexpr (!is_same_v<DType, KV_DType>) {
// fragment layout swizzle between threads
frag_B_layout_swizzle(tSrK_copy_view, tidx);
}
mixed_gemm(tiled_mma, tSrQ(_, _, ki), tSrK(_, _, ki), tSrAccS);
}
};

// GEMM-II: O = softmax(S)@V
auto tOrVt = thr_mma.partition_fragment_B(sVt); // (MMA,MMA_K,MMA_N)
// allocate register with type KV_DType for V^t
auto tOrVt = make_fragment_B<KV_DType>(thr_mma, sVt); // (MMA,MMA_K,MMA_N)

SmemTiledCopyVt smem_tiled_copy_Vt;
auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_thread_slice(tidx);
@@ -254,7 +261,12 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
tOsVt(_, _, next_ki),
tOrVt_copy_view(_, _, next_ki));
}
cute::gemm(tiled_mma, tOrS(_, _, ki), tOrVt(_, _, ki), tOrAccO);

if constexpr (!is_same_v<DType, KV_DType>) {
// fragment layout swizzle between threads
frag_B_trans_layout_swizzle(tOrVt_copy_view, tidx);
}
mixed_gemm(tiled_mma, tOrS(_, _, ki), tOrVt(_, _, ki), tOrAccO);
}
};

@@ -288,10 +300,10 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) {

// wait for smem copy done before gmem copy
__syncthreads();
safe_copy<EVEN_K,
/*EVEN_MN=*/false,
/*ZERO_FILL_MN=*/false,
/*ZERO_FILL_K=*/false>(
safe_copy</*EVEN_MN=*/false,
EVEN_K,
/*ZFILL_MN=*/false,
/*ZFILL_K=*/false>(
gmem_tiled_copy_O,
tOsO,
tOgO,
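The mixed-dtype helpers this kernel now calls (make_fragment_B, frag_B_layout_swizzle, frag_B_trans_layout_swizzle, mixed_gemm) are presumably provided by the new layout_conformance.cuh include and are not shown in this diff. As a rough mental model only, the sketch below covers just the dtype side of a mixed_gemm-style call: when the B-operand fragment is held in KV_DType (e.g. fp8) while the MMA computes in DType, up-convert it in registers and issue the usual cute::gemm. The name mixed_gemm_sketch and the element-wise conversion loop are assumptions; the sketch deliberately ignores the inter-thread fragment layout swizzle the real kernel performs before the MMA.

#include <cute/tensor.hpp>

// Hypothetical sketch, not the PR's implementation.
template <class TiledMma, class FragA, class FragB, class FragC>
CUTE_DEVICE void mixed_gemm_sketch(TiledMma& tiled_mma,
                                   FragA const& tCrA,  // (MMA,MMA_M,MMA_K) in DType
                                   FragB const& tCrB,  // (MMA,MMA_N,MMA_K) in KV_DType
                                   FragC& tCrC) {      // accumulator
  using DType = typename FragA::value_type;
  using KV_DType = typename FragB::value_type;
  if constexpr (cute::is_same_v<DType, KV_DType>) {
    // same dtype for Q and KV: plain MMA path
    cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
  } else {
    // up-convert the KV fragment element-wise; the real kernel presumably
    // uses a vectorized conversion (see fast_cast.cuh) instead
    auto tCrB_cvt = cute::make_fragment_like<DType>(tCrB);
    CUTE_UNROLL
    for (int i = 0; i < cute::size(tCrB); ++i) {
      tCrB_cvt(i) = DType(float(tCrB(i)));
    }
    cute::gemm(tiled_mma, tCrA, tCrB_cvt, tCrC);
  }
}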
@@ -63,7 +63,7 @@ torch::Tensor attention_pagedkv_sm80(

DISPATCH_TORCH_DTYPE(query.dtype(), DTYPE, [&] {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
run_attention_kernel_sm80<DTYPE, HEAD_DIM>(params);
run_attention_kernel_sm80<DTYPE, DTYPE, HEAD_DIM>(params);
});
});
return out;
2 changes: 1 addition & 1 deletion src/kernels/attention/attention_kernel_sm80_test.cu
@@ -59,7 +59,7 @@ torch::Tensor attention_sm80(

DISPATCH_TORCH_DTYPE(query.dtype(), DTYPE, [&] {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
run_attention_kernel_sm80<DTYPE, HEAD_DIM>(params);
run_attention_kernel_sm80<DTYPE, DTYPE, HEAD_DIM>(params);
});
});
return out;
2 changes: 1 addition & 1 deletion src/kernels/attention/attention_kernel_sm80_varlen_test.cu
@@ -156,7 +156,7 @@ torch::Tensor attention_varlen_sm80(

DISPATCH_TORCH_DTYPE(query.dtype(), DTYPE, [&] {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
run_attention_kernel_sm80<DTYPE, HEAD_DIM>(params);
run_attention_kernel_sm80<DTYPE, DTYPE, HEAD_DIM>(params);
});
});
return out;
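The three call sites above pass the same dtype twice (run_attention_kernel_sm80<DTYPE, DTYPE, HEAD_DIM>), so they exercise the new signature without changing behavior. A binding that actually enables an fp8 KV cache would pick KV_DTYPE from the key cache's dtype; the snippet below is a hypothetical sketch of such a site. The key tensor name, the Float8_e4m3fn check (which assumes a PyTorch build with that dtype), and cute::float_e4m3_t are assumptions, not code from this PR.

// Hypothetical binding-site dispatch, not part of this diff.
DISPATCH_TORCH_DTYPE(query.dtype(), DTYPE, [&] {
  DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
    if (key.scalar_type() == at::ScalarType::Float8_e4m3fn) {
      // fp8 KV cache, compute in DTYPE
      run_attention_kernel_sm80<DTYPE, cute::float_e4m3_t, HEAD_DIM>(params);
    } else {
      // KV cache in the same dtype as the query
      run_attention_kernel_sm80<DTYPE, DTYPE, HEAD_DIM>(params);
    }
  });
});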
17 changes: 11 additions & 6 deletions src/kernels/attention/attention_launch_sm80.cuh
@@ -51,43 +51,48 @@ void run_attention_kernel(const Params& params, cudaStream_t stream) {
} // namespace detail

// user-facing function to run the attention kernel
template <typename Element, int HEAD_DIM, typename Params>
template <typename DTYPE, typename KV_DTYPE, int HEAD_DIM, typename Params>
void run_attention_kernel_sm80(Params& params, cudaStream_t stream = nullptr) {
// normalize params for performance optimization
params.normalize();

// TODO: tune block shape MNK based on the head dim and smem size
if constexpr (HEAD_DIM == 64) {
using Traits = AttentionTraitsSM80<Element,
using Traits = AttentionTraitsSM80<DTYPE,
KV_DTYPE,
HEAD_DIM,
/*BLK_M=*/64,
/*BLK_N=*/64,
/*BLK_K=*/64>;
detail::run_attention_kernel<Traits>(params, stream);
} else if constexpr (HEAD_DIM == 96) {
using Traits = AttentionTraitsSM80<Element,
using Traits = AttentionTraitsSM80<DTYPE,
KV_DTYPE,
HEAD_DIM,
/*BLK_M=*/64,
/*BLK_N=*/64,
/*BLK_K=*/32>;
detail::run_attention_kernel<Traits>(params, stream);
} else if constexpr (HEAD_DIM == 128) {
using Traits = AttentionTraitsSM80<Element,
using Traits = AttentionTraitsSM80<DTYPE,
KV_DTYPE,
HEAD_DIM,
/*BLK_M=*/64,
/*BLK_N=*/64,
/*BLK_K=*/64>;
detail::run_attention_kernel<Traits>(params, stream);
} else if constexpr (HEAD_DIM == 256) {
using Traits = AttentionTraitsSM80<Element,
using Traits = AttentionTraitsSM80<DTYPE,
KV_DTYPE,
HEAD_DIM,
/*BLK_M=*/64,
/*BLK_N=*/64,
/*BLK_K=*/64>;
detail::run_attention_kernel<Traits>(params, stream);
} else {
// use the default block size
using Traits = AttentionTraitsSM80<Element,
using Traits = AttentionTraitsSM80<DTYPE,
KV_DTYPE,
HEAD_DIM,
/*BLK_M=*/64,
/*BLK_N=*/64,