From 8efeb7f5a372e916fe33a1aa176ae63121955eef Mon Sep 17 00:00:00 2001 From: skrider Date: Thu, 8 Feb 2024 21:59:54 +0000 Subject: [PATCH 01/19] add print statements for debugging --- csrc/flash_attn/src/debug.h | 39 ++++++++++++++++ csrc/flash_attn/src/flash_fwd_kernel.h | 64 ++++++++++++++++++++++++++ 2 files changed, 103 insertions(+) create mode 100644 csrc/flash_attn/src/debug.h diff --git a/csrc/flash_attn/src/debug.h b/csrc/flash_attn/src/debug.h new file mode 100644 index 000000000..2870f6198 --- /dev/null +++ b/csrc/flash_attn/src/debug.h @@ -0,0 +1,39 @@ +#include + +#define KIN_PRINT(tag, statement) \ + if (cute::thread0()) { \ + printf("[kin:start:%s]\n", tag); \ + statement; \ + printf("\n[kin:end:%s]\n", tag); \ + } + +template +void +print_traits() { + // bool + printf("Kernel_traits::Share_Q_K_smem : %s\n", Kernel_traits::Share_Q_K_smem ); + printf("Kernel_traits::Is_Q_in_regs : %s\n", Kernel_traits::Is_Q_in_regs ); + + // int + printf("Kernel_traits::kNWarps : %s\n", Kernel_traits::kNWarps ); + printf("Kernel_traits::kNThreads : %s\n", Kernel_traits::kNThreads ); + printf("Kernel_traits::kBlockM : %s\n", Kernel_traits::kBlockM ); + printf("Kernel_traits::kBlockN : %s\n", Kernel_traits::kBlockN ); + printf("Kernel_traits::kHeadDim : %s\n", Kernel_traits::kHeadDim ); + printf("Kernel_traits::kBlockKSmem : %s\n", Kernel_traits::kBlockKSmem ); + printf("Kernel_traits::kBlockKGmem : %s\n", Kernel_traits::kBlockKGmem ); + printf("Kernel_traits::kSwizzle : %s\n", Kernel_traits::kSwizzle ); + printf("Kernel_traits::kSmemQSize : %s\n", Kernel_traits::kSmemQSize ); + printf("Kernel_traits::kSmemKVSize : %s\n", Kernel_traits::kSmemKVSize ); + printf("Kernel_traits::kSmemSize : %s\n", Kernel_traits::kSmemSize ); + printf("Kernel_traits::kGmemElemsPerLoad : %s\n", Kernel_traits::kGmemElemsPerLoad ); + + // cute object + printf("Kernel_traits::GmemLayoutAtom : "); print(Kernel_traits::GmemLayoutAtom); printf("\n"); + printf("Kernel_traits::GmemTiledCopyQKV : "); print(Kernel_traits::GmemTiledCopyQKV); printf("\n"); + printf("Kernel_traits::GmemTiledCopyO : "); print(Kernel_traits::GmemTiledCopyO); printf("\n"); + printf("Kernel_traits::SmemCopyAtom : "); print(Kernel_traits::SmemCopyAtom); printf("\n"); + printf("Kernel_traits::SmemCopyAtomTransposed : "); print(Kernel_traits::SmemCopyAtomTransposed); printf("\n"); + printf("Kernel_traits::MMA_Atom_Arch : "); print(Kernel_traits::MMA_Atom_Arch); printf("\n"); +} + diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index bd29d5670..62e07ef59 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -18,6 +18,8 @@ #include "dropout.h" #include "rotary.h" +#include "debug.h" + namespace flash { using namespace cute; @@ -41,6 +43,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kNWarps = Kernel_traits::kNWarps; +#if 0 + KIN_PRINT("Kernel_traits", print_traits()); +#endif auto seed_offset = at::cuda::philox::unpack(params.philox_args); flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t, @@ -55,6 +60,18 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q) return; +#if 0 + // const int sum_s_q; + // const int sum_s_k; + // const int actual_seqlen_q; + // 
const int seqlen_k_cache; + // const int actual_seqlen_k; + KIN_PRINT("binfo.sum_s_q", printf("%d", binfo.sum_s_q)) + KIN_PRINT("binfo.sum_s_k", printf("%d", binfo.sum_s_k)) + KIN_PRINT("binfo.actual_seqlen_q", printf("%d", binfo.actual_seqlen_q)) + KIN_PRINT("binfo.seqlen_k_cache", printf("%d", binfo.seqlen_k_cache)) + KIN_PRINT("binfo.actual_seqlen_k", printf("%d", binfo.actual_seqlen_k)) +#endif const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); @@ -136,10 +153,24 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)), typename Kernel_traits::SmemLayoutKV{}); + +#if 1 + KIN_PRINT("sK.layout()", print(sK.layout())) + KIN_PRINT("gK.layout()", print(gK.layout())) + KIN_PRINT("Share_Q_K_smem", printf("%d", Kernel_traits::Share_Q_K_smem)) +#endif + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); +#if 1 + KIN_PRINT("sV.layout()", print(sV.layout())) + KIN_PRINT("sVt.layout()", print(sVt.layout())) + KIN_PRINT("sVtNoSwizzle.layout()", print(sVtNoSwizzle.layout())) + KIN_PRINT("Share_Q_K_smem", printf("%d", Kernel_traits::Share_Q_K_smem)) +#endif + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); @@ -150,16 +181,30 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); +#if 1 + KIN_PRINT("tKgK.layout()", print(tKgK.layout())) + KIN_PRINT("tKsK.layout()", print(tKsK.layout())) +#endif + typename Kernel_traits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) +#if 1 + KIN_PRINT("tSrQ.layout()", print(tSrQ.layout())) + KIN_PRINT("tSrK.layout()", print(tSrK.layout())) +#endif + Tensor tSgS = thr_mma.partition_C(gP); Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K +#if 1 + KIN_PRINT("acc_o.layout()", print(acc_o.layout())) +#endif + // // Copy Atom retiling // @@ -168,11 +213,18 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); // if (cute::thread0()) {smem_thr_copy_Q.print_all();} Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); +#if 1 + KIN_PRINT("smem_thr_copy_Q.print_all()", smem_thr_copy_Q.print_all()) + KIN_PRINT("tSsQ.layout()", print(tSsQ.layout())) +#endif // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSsK = smem_thr_copy_K.partition_S(sK); +# if 1 + KIN_PRINT("tSsK.layout()", print(tSsK.layout())) +#endif auto smem_tiled_copy_V = 
make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); @@ -189,6 +241,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Construct identity layout for sQ and sK Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) +#if 1 + KIN_PRINT("cQ.layout()", print(cQ.layout())) + KIN_PRINT("cKV.layout()", print(cKV.layout())) +#endif // Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K) // if (cute::thread0()) { // print(tScQ.layout()); printf("\n"); @@ -205,10 +261,18 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Repeat the partitioning with identity layouts Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) +#if 1 + KIN_PRINT("tQcQ.layout()", print(tQcQ.layout())) + KIN_PRINT("tKVcKV.layout()", print(tKVcKV.layout())) +#endif // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); +#if 1 + KIN_PRINT("tQpQ.layout()", print(tQpQ.layout())) + KIN_PRINT("tKVpKV.layout()", print(tKVpKV.layout())) +#endif // Set predicates for k bounds if (!Is_even_K) { From ac5e78a6e6c48610de9954052695682c425faa47 Mon Sep 17 00:00:00 2001 From: skrider Date: Fri, 9 Feb 2024 00:05:39 +0000 Subject: [PATCH 02/19] add print statements for debugging --- csrc/flash_attn/src/debug.h | 61 +++++++++++-------- csrc/flash_attn/src/flash_fwd_kernel.h | 82 ++++++++++++++++++-------- 2 files changed, 92 insertions(+), 51 deletions(-) diff --git a/csrc/flash_attn/src/debug.h b/csrc/flash_attn/src/debug.h index 2870f6198..0d8250bd4 100644 --- a/csrc/flash_attn/src/debug.h +++ b/csrc/flash_attn/src/debug.h @@ -1,39 +1,50 @@ #include +#include "block_info.h" + +#pragma once #define KIN_PRINT(tag, statement) \ - if (cute::thread0()) { \ - printf("[kin:start:%s]\n", tag); \ + if (thread0()) { \ + printf("\n[kin:start:%s]\n", tag); \ statement; \ printf("\n[kin:end:%s]\n", tag); \ } +#define KIN_PRINT_BOOL(tag, BOOL) \ + if (thread0()) { \ + printf("\n[kin:start:%s]\n", tag); \ + printf("%s", BOOL ? "true" : "false"); \ + printf("\n[kin:end:%s]\n", tag); \ + } + template -void +__forceinline__ __device__ void print_traits() { // bool - printf("Kernel_traits::Share_Q_K_smem : %s\n", Kernel_traits::Share_Q_K_smem ); - printf("Kernel_traits::Is_Q_in_regs : %s\n", Kernel_traits::Is_Q_in_regs ); + printf("Kernel_traits::Share_Q_K_smem : %s\n", Kernel_traits::Share_Q_K_smem ? "true" : "false"); + printf("Kernel_traits::Is_Q_in_regs : %s\n", Kernel_traits::Is_Q_in_regs ? 
"true" : "false"); // int - printf("Kernel_traits::kNWarps : %s\n", Kernel_traits::kNWarps ); - printf("Kernel_traits::kNThreads : %s\n", Kernel_traits::kNThreads ); - printf("Kernel_traits::kBlockM : %s\n", Kernel_traits::kBlockM ); - printf("Kernel_traits::kBlockN : %s\n", Kernel_traits::kBlockN ); - printf("Kernel_traits::kHeadDim : %s\n", Kernel_traits::kHeadDim ); - printf("Kernel_traits::kBlockKSmem : %s\n", Kernel_traits::kBlockKSmem ); - printf("Kernel_traits::kBlockKGmem : %s\n", Kernel_traits::kBlockKGmem ); - printf("Kernel_traits::kSwizzle : %s\n", Kernel_traits::kSwizzle ); - printf("Kernel_traits::kSmemQSize : %s\n", Kernel_traits::kSmemQSize ); - printf("Kernel_traits::kSmemKVSize : %s\n", Kernel_traits::kSmemKVSize ); - printf("Kernel_traits::kSmemSize : %s\n", Kernel_traits::kSmemSize ); - printf("Kernel_traits::kGmemElemsPerLoad : %s\n", Kernel_traits::kGmemElemsPerLoad ); - - // cute object - printf("Kernel_traits::GmemLayoutAtom : "); print(Kernel_traits::GmemLayoutAtom); printf("\n"); - printf("Kernel_traits::GmemTiledCopyQKV : "); print(Kernel_traits::GmemTiledCopyQKV); printf("\n"); - printf("Kernel_traits::GmemTiledCopyO : "); print(Kernel_traits::GmemTiledCopyO); printf("\n"); - printf("Kernel_traits::SmemCopyAtom : "); print(Kernel_traits::SmemCopyAtom); printf("\n"); - printf("Kernel_traits::SmemCopyAtomTransposed : "); print(Kernel_traits::SmemCopyAtomTransposed); printf("\n"); - printf("Kernel_traits::MMA_Atom_Arch : "); print(Kernel_traits::MMA_Atom_Arch); printf("\n"); + printf("Kernel_traits::kNWarps : %d\n", Kernel_traits::kNWarps ); + printf("Kernel_traits::kNThreads : %d\n", Kernel_traits::kNThreads ); + printf("Kernel_traits::kBlockM : %d\n", Kernel_traits::kBlockM ); + printf("Kernel_traits::kBlockN : %d\n", Kernel_traits::kBlockN ); + printf("Kernel_traits::kHeadDim : %d\n", Kernel_traits::kHeadDim ); + printf("Kernel_traits::kBlockKSmem : %d\n", Kernel_traits::kBlockKSmem ); + printf("Kernel_traits::kBlockKGmem : %d\n", Kernel_traits::kBlockKGmem ); + printf("Kernel_traits::kSwizzle : %d\n", Kernel_traits::kSwizzle ); + printf("Kernel_traits::kSmemQSize : %d\n", Kernel_traits::kSmemQSize ); + printf("Kernel_traits::kSmemKVSize : %d\n", Kernel_traits::kSmemKVSize ); + printf("Kernel_traits::kSmemSize : %d\n", Kernel_traits::kSmemSize ); + printf("Kernel_traits::kGmemElemsPerLoad : %d\n", Kernel_traits::kGmemElemsPerLoad ); } +template +__forceinline__ __device__ void +print_binfo(const BlockInfo& binfo) { + printf("binfo.sum_s_q : %d\n", binfo.sum_s_q); + printf("binfo.sum_s_k : %d\n", binfo.sum_s_k); + printf("binfo.actual_seqlen_q : %d\n", binfo.actual_seqlen_q); + printf("binfo.seqlen_k_cache : %d\n", binfo.seqlen_k_cache); + printf("binfo.actual_seqlen_k : %d\n", binfo.actual_seqlen_k); +} diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 62e07ef59..5a19c710a 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -43,7 +43,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kNWarps = Kernel_traits::kNWarps; -#if 0 +#if 1 KIN_PRINT("Kernel_traits", print_traits()); #endif @@ -60,17 +60,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q) return; -#if 0 - // const int sum_s_q; - // const int sum_s_k; - // 
const int actual_seqlen_q; - // const int seqlen_k_cache; - // const int actual_seqlen_k; - KIN_PRINT("binfo.sum_s_q", printf("%d", binfo.sum_s_q)) - KIN_PRINT("binfo.sum_s_k", printf("%d", binfo.sum_s_k)) - KIN_PRINT("binfo.actual_seqlen_q", printf("%d", binfo.actual_seqlen_q)) - KIN_PRINT("binfo.seqlen_k_cache", printf("%d", binfo.seqlen_k_cache)) - KIN_PRINT("binfo.actual_seqlen_k", printf("%d", binfo.actual_seqlen_k)) +#if 1 + KIN_PRINT("binfo", print_binfo(binfo)) #endif const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); @@ -153,22 +144,18 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)), typename Kernel_traits::SmemLayoutKV{}); - #if 1 KIN_PRINT("sK.layout()", print(sK.layout())) KIN_PRINT("gK.layout()", print(gK.layout())) - KIN_PRINT("Share_Q_K_smem", printf("%d", Kernel_traits::Share_Q_K_smem)) #endif Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); - #if 1 KIN_PRINT("sV.layout()", print(sV.layout())) KIN_PRINT("sVt.layout()", print(sVt.layout())) KIN_PRINT("sVtNoSwizzle.layout()", print(sVtNoSwizzle.layout())) - KIN_PRINT("Share_Q_K_smem", printf("%d", Kernel_traits::Share_Q_K_smem)) #endif typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; @@ -180,7 +167,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - #if 1 KIN_PRINT("tKgK.layout()", print(tKgK.layout())) KIN_PRINT("tKsK.layout()", print(tKsK.layout())) @@ -191,7 +177,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) - #if 1 KIN_PRINT("tSrQ.layout()", print(tSrQ.layout())) KIN_PRINT("tSrK.layout()", print(tSrK.layout())) @@ -200,7 +185,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tSgS = thr_mma.partition_C(gP); Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K - #if 1 KIN_PRINT("acc_o.layout()", print(acc_o.layout())) #endif @@ -211,10 +195,12 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); +#if 0 + KIN_PRINT("fail", smem_thr_copy_Q.print_all()); +#endif // if (cute::thread0()) {smem_thr_copy_Q.print_all();} Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); #if 1 - KIN_PRINT("smem_thr_copy_Q.print_all()", smem_thr_copy_Q.print_all()) KIN_PRINT("tSsQ.layout()", print(tSsQ.layout())) #endif // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} @@ -222,7 +208,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi auto smem_tiled_copy_K = 
make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSsK = smem_thr_copy_K.partition_S(sK); -# if 1 +#if 1 KIN_PRINT("tSsK.layout()", print(tSsK.layout())) #endif @@ -261,15 +247,13 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Repeat the partitioning with identity layouts Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) -#if 1 - KIN_PRINT("tQcQ.layout()", print(tQcQ.layout())) - KIN_PRINT("tKVcKV.layout()", print(tKVcKV.layout())) -#endif // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); #if 1 + KIN_PRINT("tQcQ.layout()", print(tQcQ.layout())) + KIN_PRINT("tKVcKV.layout()", print(tKVcKV.layout())) KIN_PRINT("tQpQ.layout()", print(tQpQ.layout())) KIN_PRINT("tKVpKV.layout()", print(tKVpKV.layout())) #endif @@ -552,6 +536,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kNWarps = Kernel_traits::kNWarps; +#if 1 + KIN_PRINT("Kernel_traits", print_traits()) + KIN_PRINT_BOOL("Is_causal", Is_causal) + KIN_PRINT_BOOL("Is_local", Is_local) + KIN_PRINT_BOOL("Has_alibi", Has_alibi) + KIN_PRINT_BOOL("Is_even_MN", Is_even_MN) + KIN_PRINT_BOOL("Is_even_K", Is_even_K) + KIN_PRINT_BOOL("Split", Split) + KIN_PRINT_BOOL("Append_KV", Append_KV) +#endif using GmemTiledCopyO = std::conditional_t< !Split, @@ -564,6 +558,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); } // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 
0 : params.seqlen_knew)); } if (m_block * kBlockM >= binfo.actual_seqlen_q) return; +#if 1 + KIN_PRINT("binfo", print_binfo(binfo)) +#endif const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; const int n_block_min = !Is_local @@ -645,13 +642,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), Shape, Int>{}, make_stride(params.v_row_stride, _1{})); - Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQ{}); Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); +#if 1 + KIN_PRINT("sK.layout()", print(sK.layout())) + KIN_PRINT("gK.layout()", print(gK.layout())) + KIN_PRINT("sV.layout()", print(sV.layout())) + KIN_PRINT("sVt.layout()", print(sVt.layout())) + KIN_PRINT("sVtNoSwizzle.layout()", print(sVtNoSwizzle.layout())) +#endif typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); @@ -662,14 +665,25 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); +#if 1 + KIN_PRINT("tKgK.layout()", print(tKgK.layout())) + KIN_PRINT("tKsK.layout()", print(tKsK.layout())) +#endif typename Kernel_traits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) +#if 1 + KIN_PRINT("tSrQ.layout()", print(tSrQ.layout())) + KIN_PRINT("tSrK.layout()", print(tSrK.layout())) +#endif Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K +#if 1 + KIN_PRINT("acc_o.layout()", print(acc_o.layout())) +#endif // // Copy Atom retiling @@ -678,10 +692,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); +#if 1 + KIN_PRINT("tSsQ.layout()", print(tSsQ.layout())) +#endif auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSsK = smem_thr_copy_K.partition_S(sK); +#if 1 + KIN_PRINT("tSsK.layout()", print(tSsK.layout())) +#endif auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); @@ -697,6 +717,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Construct identity layout for sQ and sK Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> 
(blk_n,blk_k) +#if 1 + KIN_PRINT("cQ.layout()", print(cQ.layout())) + KIN_PRINT("cKV.layout()", print(cKV.layout())) +#endif // Repeat the partitioning with identity layouts Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) @@ -705,6 +729,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); +#if 1 + KIN_PRINT("tQcQ.layout()", print(tQcQ.layout())) + KIN_PRINT("tKVcKV.layout()", print(tKVcKV.layout())) + KIN_PRINT("tQpQ.layout()", print(tQpQ.layout())) + KIN_PRINT("tKVpKV.layout()", print(tKVpKV.layout())) +#endif // Set predicates for k bounds if (!Is_even_K) { From 14b190bc2cd595c7b238f37089a7f9a4fc7ace64 Mon Sep 17 00:00:00 2001 From: skrider Date: Sun, 11 Feb 2024 00:42:47 +0000 Subject: [PATCH 03/19] reshape gmem copy --- csrc/flash_attn/src/debug.h | 22 ++-- csrc/flash_attn/src/flash_fwd_kernel.h | 143 +++++++++++++------------ csrc/flash_attn/src/kernel_traits.h | 11 ++ 3 files changed, 103 insertions(+), 73 deletions(-) diff --git a/csrc/flash_attn/src/debug.h b/csrc/flash_attn/src/debug.h index 0d8250bd4..e85437be4 100644 --- a/csrc/flash_attn/src/debug.h +++ b/csrc/flash_attn/src/debug.h @@ -3,18 +3,18 @@ #pragma once -#define KIN_PRINT(tag, statement) \ +#define KIN_PRINT(statement) \ if (thread0()) { \ - printf("\n[kin:start:%s]\n", tag); \ + printf("\n[kin:start:%s]\n", #statement); \ statement; \ - printf("\n[kin:end:%s]\n", tag); \ + printf("\n[kin:end:%s]\n", #statement); \ } -#define KIN_PRINT_BOOL(tag, BOOL) \ +#define KIN_PRINT_BOOL(BOOL) \ if (thread0()) { \ - printf("\n[kin:start:%s]\n", tag); \ + printf("\n[kin:start:%s]\n", #BOOL); \ printf("%s", BOOL ? "true" : "false"); \ - printf("\n[kin:end:%s]\n", tag); \ + printf("\n[kin:end:%s]\n", #BOOL); \ } template @@ -36,7 +36,17 @@ print_traits() { printf("Kernel_traits::kSmemQSize : %d\n", Kernel_traits::kSmemQSize ); printf("Kernel_traits::kSmemKVSize : %d\n", Kernel_traits::kSmemKVSize ); printf("Kernel_traits::kSmemSize : %d\n", Kernel_traits::kSmemSize ); + printf("Kernel_traits::kGmemRowsPerThread: %d\n", Kernel_traits::kGmemRowsPerThread ); printf("Kernel_traits::kGmemElemsPerLoad : %d\n", Kernel_traits::kGmemElemsPerLoad ); + + // cute object + printf("Kernel_traits::GmemLayoutAtom : "); + cute::print(Kernel_traits::GmemLayoutAtom()); + printf("\n"); + printf("Kernel_traits::GmemTiledCopyQKV :\n"); + cute::print(Kernel_traits::GmemTiledCopyQKV()); + printf("\n"); + } template diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 5a19c710a..35dbc4f33 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -44,7 +44,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kNWarps = Kernel_traits::kNWarps; #if 1 - KIN_PRINT("Kernel_traits", print_traits()); + KIN_PRINT(print_traits()); #endif auto seed_offset = at::cuda::philox::unpack(params.philox_args); @@ -61,7 +61,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q) return; #if 1 - KIN_PRINT("binfo", print_binfo(binfo)) + KIN_PRINT(print_binfo(binfo)) #endif const int n_block_min = !Is_local ? 
0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); @@ -145,17 +145,17 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)), typename Kernel_traits::SmemLayoutKV{}); #if 1 - KIN_PRINT("sK.layout()", print(sK.layout())) - KIN_PRINT("gK.layout()", print(gK.layout())) + KIN_PRINT(print(sK.layout())) + KIN_PRINT(print(gK.layout())) #endif Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); #if 1 - KIN_PRINT("sV.layout()", print(sV.layout())) - KIN_PRINT("sVt.layout()", print(sVt.layout())) - KIN_PRINT("sVtNoSwizzle.layout()", print(sVtNoSwizzle.layout())) + KIN_PRINT(print(sV.layout())) + KIN_PRINT(print(sVt.layout())) + KIN_PRINT(print(sVtNoSwizzle.layout())) #endif typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; @@ -168,8 +168,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); #if 1 - KIN_PRINT("tKgK.layout()", print(tKgK.layout())) - KIN_PRINT("tKsK.layout()", print(tKsK.layout())) + KIN_PRINT(print(tKgK.layout())) + KIN_PRINT(print(tKsK.layout())) #endif typename Kernel_traits::TiledMma tiled_mma; @@ -178,15 +178,15 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) #if 1 - KIN_PRINT("tSrQ.layout()", print(tSrQ.layout())) - KIN_PRINT("tSrK.layout()", print(tSrK.layout())) + KIN_PRINT(print(tSrQ.layout())) + KIN_PRINT(print(tSrK.layout())) #endif Tensor tSgS = thr_mma.partition_C(gP); Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K #if 1 - KIN_PRINT("acc_o.layout()", print(acc_o.layout())) + KIN_PRINT(print(acc_o.layout())) #endif // @@ -196,12 +196,12 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); #if 0 - KIN_PRINT("fail", smem_thr_copy_Q.print_all()); + KIN_PRINT(smem_thr_copy_Q.print_all()); #endif // if (cute::thread0()) {smem_thr_copy_Q.print_all();} Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); #if 1 - KIN_PRINT("tSsQ.layout()", print(tSsQ.layout())) + KIN_PRINT(print(tSsQ.layout())) #endif // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} @@ -209,7 +209,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSsK = smem_thr_copy_K.partition_S(sK); #if 1 - KIN_PRINT("tSsK.layout()", print(tSsK.layout())) + KIN_PRINT(print(tSsK.layout())) #endif auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); @@ -228,8 +228,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), 
size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) #if 1 - KIN_PRINT("cQ.layout()", print(cQ.layout())) - KIN_PRINT("cKV.layout()", print(cKV.layout())) + KIN_PRINT(print(cQ.layout())) + KIN_PRINT(print(cKV.layout())) #endif // Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K) // if (cute::thread0()) { @@ -252,10 +252,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); #if 1 - KIN_PRINT("tQcQ.layout()", print(tQcQ.layout())) - KIN_PRINT("tKVcKV.layout()", print(tKVcKV.layout())) - KIN_PRINT("tQpQ.layout()", print(tQpQ.layout())) - KIN_PRINT("tKVpKV.layout()", print(tKVpKV.layout())) + KIN_PRINT(print(tQcQ.layout())) + KIN_PRINT(print(tKVcKV.layout())) + KIN_PRINT(print(tQpQ.layout())) + KIN_PRINT(print(tKVpKV.layout())) #endif // Set predicates for k bounds @@ -537,14 +537,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kNWarps = Kernel_traits::kNWarps; #if 1 - KIN_PRINT("Kernel_traits", print_traits()) - KIN_PRINT_BOOL("Is_causal", Is_causal) - KIN_PRINT_BOOL("Is_local", Is_local) - KIN_PRINT_BOOL("Has_alibi", Has_alibi) - KIN_PRINT_BOOL("Is_even_MN", Is_even_MN) - KIN_PRINT_BOOL("Is_even_K", Is_even_K) - KIN_PRINT_BOOL("Split", Split) - KIN_PRINT_BOOL("Append_KV", Append_KV) + KIN_PRINT(print_traits()) + KIN_PRINT_BOOL(Is_causal) + KIN_PRINT_BOOL(Is_local) + KIN_PRINT_BOOL(Has_alibi) + KIN_PRINT_BOOL(Is_even_MN) + KIN_PRINT_BOOL(Is_even_K) + KIN_PRINT_BOOL(Split) + KIN_PRINT_BOOL(Append_KV) #endif using GmemTiledCopyO = std::conditional_t< @@ -559,7 +559,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 
0 : params.seqlen_knew)); } if (m_block * kBlockM >= binfo.actual_seqlen_q) return; #if 1 - KIN_PRINT("binfo", print_binfo(binfo)) + KIN_PRINT(print_binfo(binfo)) #endif const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; @@ -649,25 +649,34 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); #if 1 - KIN_PRINT("sK.layout()", print(sK.layout())) - KIN_PRINT("gK.layout()", print(gK.layout())) - KIN_PRINT("sV.layout()", print(sV.layout())) - KIN_PRINT("sVt.layout()", print(sVt.layout())) - KIN_PRINT("sVtNoSwizzle.layout()", print(sVtNoSwizzle.layout())) + KIN_PRINT(print(sK.layout())) + KIN_PRINT(print(gK.layout())) + KIN_PRINT(print(sV.layout())) + KIN_PRINT(print(sVt.layout())) + KIN_PRINT(print(sVtNoSwizzle.layout())) #endif - typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; - auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_Q; + auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyQKVPaged gmem_tiled_copy_KV; + auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_KV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_KV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_KV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_KV.partition_D(sV); +#if 1 + KIN_PRINT(print(tKgK.layout())) + KIN_PRINT(print(tKsK.layout())) +#endif - Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); - Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) - Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) - Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); #if 1 - KIN_PRINT("tKgK.layout()", print(tKgK.layout())) - KIN_PRINT("tKsK.layout()", print(tKsK.layout())) + fill(tVgV, 1.f * ((Element) tidx)); + __syncthreads(); + + KIN_PRINT(print_tensor(gV)) #endif typename Kernel_traits::TiledMma tiled_mma; @@ -676,13 +685,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) #if 1 - KIN_PRINT("tSrQ.layout()", print(tSrQ.layout())) - KIN_PRINT("tSrK.layout()", print(tSrK.layout())) + KIN_PRINT(print(tSrQ.layout())) + KIN_PRINT(print(tSrK.layout())) #endif Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K #if 1 - KIN_PRINT("acc_o.layout()", print(acc_o.layout())) + KIN_PRINT(print(acc_o.layout())) #endif // @@ -693,14 +702,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); #if 1 - KIN_PRINT("tSsQ.layout()", print(tSsQ.layout())) + KIN_PRINT(print(tSsQ.layout())) #endif auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSsK = 
smem_thr_copy_K.partition_S(sK); #if 1 - KIN_PRINT("tSsK.layout()", print(tSsK.layout())) + KIN_PRINT(print(tSsK.layout())) #endif auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); @@ -718,22 +727,22 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) #if 1 - KIN_PRINT("cQ.layout()", print(cQ.layout())) - KIN_PRINT("cKV.layout()", print(cKV.layout())) + KIN_PRINT(print(cQ.layout())) + KIN_PRINT(print(cKV.layout())) #endif // Repeat the partitioning with identity layouts - Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_KV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); #if 1 - KIN_PRINT("tQcQ.layout()", print(tQcQ.layout())) - KIN_PRINT("tKVcKV.layout()", print(tKVcKV.layout())) - KIN_PRINT("tQpQ.layout()", print(tQpQ.layout())) - KIN_PRINT("tKVpKV.layout()", print(tKVpKV.layout())) + KIN_PRINT(print(tQcQ.layout())) + KIN_PRINT(print(tKVcKV.layout())) + KIN_PRINT(print(tQpQ.layout())) + KIN_PRINT(print(tKVpKV.layout())) #endif // Set predicates for k bounds @@ -792,8 +801,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), Shape, Int>{}, make_stride(params.vnew_row_stride, _1{})); - Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) - Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + Tensor tKgKnew = gmem_thr_copy_KV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tVgVnew = gmem_thr_copy_KV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); auto tKgK_data = tKgK.data(); @@ -853,7 +862,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Read Q from gmem to smem, optionally apply rotary embedding. if (!Append_KV || params.rotary_dim == 0) { // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM); } else { const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); @@ -890,7 +899,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons int n_block = n_block_max - 1; // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. 
- flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + flash::copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); @@ -935,11 +944,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; } - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV); } else { // Clear the smem tiles to account for predicated off loads flash::copy( - gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); } cute::cp_async_fence(); @@ -970,7 +979,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; } - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -1013,7 +1022,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; } - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); flash::gemm( @@ -1034,7 +1043,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; } - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
cute::cp_async_fence();
diff --git a/csrc/flash_attn/src/kernel_traits.h b/csrc/flash_attn/src/kernel_traits.h
index a7a5cf1ed..5d6cab9d9 100644
--- a/csrc/flash_attn/src/kernel_traits.h
+++ b/csrc/flash_attn/src/kernel_traits.h
@@ -131,6 +131,17 @@ struct Flash_fwd_kernel_traits : public Base {
         make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
                         GmemLayoutAtom{},
                         Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
+
+    // how many rows each thread has to fetch from
+    static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow);
+    // Here we assign a contiguous tile to each thread, rather than a 1x8 row every
+    // (kNThreads / kGmemThreadsPerRow) rows, ensuring that the elements assigned to each thread
+    // do not cross a page boundary. This way, each thread need only fetch 1 page index per
+    // mainloop iteration. Rudimentary testing shows no slowdown.
+    using GmemTiledCopyQKVPaged = decltype(
+        make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
+                        GmemLayoutAtom{},
+                        Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{}));
     using GmemTiledCopyO = decltype(
         make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                         GmemLayoutAtom{},
From 409431b812bf5ef85096232552735a8ae5b2f87c Mon Sep 17 00:00:00 2001
From: skrider
Date: Sun, 11 Feb 2024 01:43:27 +0000
Subject: [PATCH 04/19] only test trivial block size
---
 csrc/cutlass             | 2 +-
 tests/test_flash_attn.py | 9 ++++-----
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/csrc/cutlass b/csrc/cutlass
index bbe579a9e..751eb9a88 160000
--- a/csrc/cutlass
+++ b/csrc/cutlass
@@ -1 +1 @@
-Subproject commit bbe579a9e3beb6ea6626d9227ec32d0dae119a49
+Subproject commit 751eb9a8859ac36bfc77551f9e4a957c31a5a8b1
diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index 308e30bec..05f94490b 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -1832,11 +1832,10 @@ def test_flash_attn_splitkv(
 # @pytest.mark.parametrize("rotary_interleaved", [False])
 @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
-@pytest.mark.parametrize("paged_kv_block_size", [None, 256])
-# @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
-# @pytest.mark.parametrize("paged_kv_block_size", [256])
-@pytest.mark.parametrize("has_batch_idx", [False, True])
-# @pytest.mark.parametrize("has_batch_idx", [False])
+# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
+@pytest.mark.parametrize("paged_kv_block_size", [256])
+# @pytest.mark.parametrize("has_batch_idx", [False, True])
+@pytest.mark.parametrize("has_batch_idx", [False])
 @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
 # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
From 70dd04996aa874b2ac9e40bc81461a44d17a1729 Mon Sep 17 00:00:00 2001
From: skrider
Date: Sun, 11 Feb 2024 02:41:44 +0000
Subject: [PATCH 05/19] implement kv page iteration functions
---
 csrc/flash_attn/src/flash_fwd_kernel.h | 49 ++++++++++++--------------
 csrc/flash_attn/src/utils.h            | 49 ++++++++++++++++++++++++++
 2 files changed, 71 insertions(+), 27 deletions(-)

diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h
index 35dbc4f33..a14c7b558 100644
--- a/csrc/flash_attn/src/flash_fwd_kernel.h
+++ b/csrc/flash_attn/src/flash_fwd_kernel.h
@@ -621,16 +621,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
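    // A worked example of the page-table arithmetic this patch builds on (a
    // sketch with assumed values: kBlockN = 128, page_block_size = 256; neither
    // number is taken from this commit). For the first row of K block n_block = 3:
    //     row              = n_block * kBlockN     = 384
    //     virtual_page_idx = row / page_block_size = 1
    //     page_offset      = row % page_block_size = 128
    // so that row lives in physical page block_table[1], at offset
    //     block_table[1] * params.k_batch_stride
    //         + 128 * params.k_row_stride
    //         + (bidh / params.h_h_k_ratio) * params.k_head_stride
    // The helpers added to utils.h below resolve this per thread rather than per block.
    // We move K and V to the last block.
    const int bidb_cache = params.cache_batch_idx == nullptr ? 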
bidb : params.cache_batch_idx[bidb]; const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride; - const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size; - const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size; const index_t row_offset_k = block_table == nullptr ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride - : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + : init_thread_kv_page_slice_offset(tidx, bidh / params.h_h_k_ratio, n_block_max, params.page_block_size, block_table, + params.k_batch_stride, params.k_row_stride, params.k_head_stride); const index_t row_offset_v = block_table == nullptr ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride - : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + : init_thread_kv_page_slice_offset(tidx, bidh / params.h_h_k_ratio, n_block_max, params.page_block_size, block_table, + params.v_batch_stride, params.v_row_stride, params.v_head_stride); Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, @@ -842,14 +842,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); } else { if (n_block > n_block_copy_min) { - const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; - const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; - const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; - const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur]; - const int offset_diff = block_table_offset_next - block_table_offset_cur; - tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride; - tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride; + // const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + // const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + // const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; + // const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; + // const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur]; + // const int offset_diff = block_table_offset_next - block_table_offset_cur; + // tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride; + // tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride; + tVgV.data() = tVgV.data() + advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, block_table, + params.v_batch_stride, 
params.v_row_stride); + tKgK.data() = tKgK.data() + advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, block_table, + params.k_batch_stride, params.k_row_stride); } } } @@ -973,11 +977,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons if (block_table == nullptr) { tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); } else { - const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; - const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; - const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; - tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + tKgK.data() = tKgK.data() + advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, block_table, + params.k_batch_stride, params.k_row_stride); } flash::copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization @@ -1016,11 +1017,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons if (block_table == nullptr) { tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); } else { - const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; - const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = n_block * kBlockN / params.page_block_size; - const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; - tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; + tVgV.data() = tVgV.data() + advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, + block_table, params.v_batch_stride, params.v_row_stride); } flash::copy(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); @@ -1037,11 +1035,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons if (block_table == nullptr) { tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); } else { - const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; - const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; - const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; - tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + tKgK.data() = tKgK.data() + advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, + block_table, params.k_batch_stride, params.k_row_stride); } flash::copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h index 2b45e87b2..36ab3925b 100644 --- a/csrc/flash_attn/src/utils.h 
+++ b/csrc/flash_attn/src/utils.h
@@ -292,6 +292,55 @@ void cp_async_wait() {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+// resolves initial base address of a slice of a paged kv copy from gmem
+template
+__forceinline__ __device__
+int init_thread_kv_page_slice_offset(const int tidx, const int hidx, const int n_block_max, const int page_block_size,
+                                     const int* block_table, const int page_stride, const int row_stride, const int head_stride) {
+    // base col of thread's slice relative to the block
+    const int col_offset = tidx % Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemElemsPerLoad;
+    // base row of thread's slice relative to the block
+    const int block_row_offset = tidx / Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemRowsPerThread;
+    // base row of thread's slice relative to the entire tensor
+    const int global_row_offset = block_row_offset + (n_block_max - 1) * Kernel_traits::kBlockN;
+    // base row of thread's slice relative to the page
+    const int page_offset = global_row_offset % page_block_size;
+
+    const int virtual_page_idx = global_row_offset / page_block_size;
+
+    return block_table[virtual_page_idx] * page_stride
+        + page_offset * row_stride
+        + hidx * head_stride
+        + col_offset;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// advances base address of a slice of a paged copy from gmem
+template
+__forceinline__ __device__
+int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const int page_block_size,
+                                        const int* block_table, const int page_stride, const int row_stride) {
+    // base row of thread's slice relative to the block
+    const int block_row_offset = tidx / Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemRowsPerThread;
+    // base row of thread's slice relative to the entire tensor
+    const int global_row_offset_cur = block_row_offset + n_block * Kernel_traits::kBlockN;
+    const int global_row_offset_next = block_row_offset + (n_block - 1) * Kernel_traits::kBlockN;
+    // base row of thread's slice relative to the page
+    const int page_offset_cur = global_row_offset_cur % page_block_size;
+    const int page_offset_next = global_row_offset_next % page_block_size;
+
+    const int virtual_page_idx_cur = global_row_offset_cur / page_block_size;
+    const int virtual_page_idx_next = global_row_offset_next / page_block_size;
+
+    const int table_diff = block_table[virtual_page_idx_next] - block_table[virtual_page_idx_cur];
+    const int offset_diff = page_offset_next - page_offset_cur;
+
+    return table_diff * page_stride + offset_diff * row_stride;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 template
From 59e76bea7661231b3d374d2699582041cd092e21 Mon Sep 17 00:00:00 2001
From: skrider
Date: Sun, 11 Feb 2024 03:12:55 +0000
Subject: [PATCH 06/19] rearrange initial offset computation
---
 csrc/flash_attn/src/flash_fwd_kernel.h | 24 +++++++++++++++---------
 csrc/flash_attn/src/utils.h            |  9 +++++----
 2 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h
index a14c7b558..f2b28186c 100644
--- a/csrc/flash_attn/src/flash_fwd_kernel.h
+++ b/csrc/flash_attn/src/flash_fwd_kernel.h
@@ -624,13 +624,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
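    // Usage sketch for the helpers above (hypothetical values: tidx = 17,
    // kGmemThreadsPerRow = 8, kGmemElemsPerLoad = 8, kGmemRowsPerThread = 8).
    // Thread 17 owns the 8x8 tile starting at row (17 / 8) * 8 = 16,
    // col (17 % 8) * 8 = 8 of each K block, so its gmem offset resolves as
    //     auto off = flash::init_thread_kv_page_slice_offset<Kernel_traits>(
    //         17, n_block_max, params.page_block_size, block_table,
    //         params.k_batch_stride, params.k_row_stride);   // slice in block n_block_max - 1
    //     off += flash::advance_thread_kv_page_slice_offset<Kernel_traits>(
    //         17, n_block_max - 1, params.page_block_size, block_table,
    //         params.k_batch_stride, params.k_row_stride);   // step to block n_block_max - 2
    // Because a thread's tile never crosses a page boundary, only one page index
    // per mainloop iteration needs to be fetched from the block table.
    const index_t row_offset_k = block_table == nullptr ? 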
binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride - : init_thread_kv_page_slice_offset(tidx, bidh / params.h_h_k_ratio, n_block_max, params.page_block_size, block_table, - params.k_batch_stride, params.k_row_stride, params.k_head_stride); + : (bidh / params.h_h_k_ratio) * params.k_head_stride; // block addresses are later resolved per-thread const index_t row_offset_v = block_table == nullptr ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride - : init_thread_kv_page_slice_offset(tidx, bidh / params.h_h_k_ratio, n_block_max, params.page_block_size, block_table, - params.v_batch_stride, params.v_row_stride, params.v_head_stride); + : (bidh / params.h_h_k_ratio) * params.v_head_stride; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, @@ -667,6 +665,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor tKsK = gmem_thr_copy_KV.partition_D(sK); Tensor tVgV = gmem_thr_copy_KV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVsV = gmem_thr_copy_KV.partition_D(sV); + + if (block_table != nullptr) { + tKgK.data() = gV.data() + flash::init_thread_kv_page_slice_offset(tidx, n_block_max, params.page_block_size, + block_table, params.k_batch_stride, params.k_row_stride); + tVgV.data() = gV.data() + flash::init_thread_kv_page_slice_offset(tidx, n_block_max, params.page_block_size, + block_table, params.v_batch_stride, params.v_row_stride); + } + #if 1 KIN_PRINT(print(tKgK.layout())) KIN_PRINT(print(tKsK.layout())) @@ -850,9 +856,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // const int offset_diff = block_table_offset_next - block_table_offset_cur; // tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride; // tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride; - tVgV.data() = tVgV.data() + advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, block_table, + tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, block_table, params.v_batch_stride, params.v_row_stride); - tKgK.data() = tKgK.data() + advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, block_table, + tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, block_table, params.k_batch_stride, params.k_row_stride); } } @@ -977,7 +983,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons if (block_table == nullptr) { tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); } else { - tKgK.data() = tKgK.data() + advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, block_table, + tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, block_table, params.k_batch_stride, params.k_row_stride); } flash::copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV); @@ -1017,7 +1023,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons if (block_table == nullptr) { tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); } else { - tVgV.data() = tVgV.data() + 
advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, + tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, block_table, params.v_batch_stride, params.v_row_stride); } flash::copy(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV); @@ -1035,7 +1041,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons if (block_table == nullptr) { tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); } else { - tKgK.data() = tKgK.data() + advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, + tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, block_table, params.k_batch_stride, params.k_row_stride); } flash::copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV); diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h index 36ab3925b..48d460e2b 100644 --- a/csrc/flash_attn/src/utils.h +++ b/csrc/flash_attn/src/utils.h @@ -292,11 +292,12 @@ void cp_async_wait() { //////////////////////////////////////////////////////////////////////////////////////////////////// -// resolves initial base address of a slice of a paged kv copy from gmem +// resolves initial base offset of a slice of a paged kv copy from gmem. +// assumes that the tensor has already been positioned at the correct head. template __forceinline__ __device__ -int init_thread_kv_page_slice_offset(const int tidx, const int hidx, const int n_block_max, const int page_block_size, - const int* block_table, const int page_stride, const int row_stride, const int head_stride) { +int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size, + const int* block_table, const int page_stride, const int row_stride) { // base col of thread's slice relative to the block const int col_offset = tidx % Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemElemsPerLoad; // base row of thread's slice relative to the block @@ -310,7 +311,6 @@ int init_thread_kv_page_slice_offset(const int tidx, const int hidx, const int n return block_table[virtual_page_idx] * page_stride + page_offset * row_stride - + hidx * head_stride + col_offset; } @@ -321,6 +321,7 @@ template __forceinline__ __device__ int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const int page_block_size, const int* block_table, const int page_stride, const int row_stride) { + return 0; // base row of thread's slice relative to the block const int block_row_offset = tidx / Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemRowsPerThread; // base col of thread's slice relative to the entire tensor From 175369fd425034d8a77070e5953f76a1086d5431 Mon Sep 17 00:00:00 2001 From: skrider Date: Sun, 11 Feb 2024 07:43:15 +0000 Subject: [PATCH 07/19] tests passing for single page k --- csrc/flash_attn/src/debug.h | 87 +++++++++++++++++- csrc/flash_attn/src/flash_fwd_kernel.h | 121 +++++-------------------- csrc/flash_attn/src/utils.h | 19 ++-- 3 files changed, 119 insertions(+), 108 deletions(-) diff --git a/csrc/flash_attn/src/debug.h b/csrc/flash_attn/src/debug.h index e85437be4..ed7f011b4 100644 --- a/csrc/flash_attn/src/debug.h +++ b/csrc/flash_attn/src/debug.h @@ -17,9 +17,89 @@ printf("\n[kin:end:%s]\n", #BOOL); \ } +__forceinline__ __device__ +void print_qkv_params(const Qkv_params& params) { + // LLM generated + printf("Qkv_params:\n"); + printf("q_ptr: %p\n", params.q_ptr); + printf("k_ptr: %p\n", params.k_ptr); + printf("v_ptr: 
%p\n", params.v_ptr); + printf("q_batch_stride: %" PRId64 "\n", params.q_batch_stride); + printf("k_batch_stride: %" PRId64 "\n", params.k_batch_stride); + printf("v_batch_stride: %" PRId64 "\n", params.v_batch_stride); + printf("q_row_stride: %" PRId64 "\n", params.q_row_stride); + printf("k_row_stride: %" PRId64 "\n", params.k_row_stride); + printf("v_row_stride: %" PRId64 "\n", params.v_row_stride); + printf("q_head_stride: %" PRId64 "\n", params.q_head_stride); + printf("k_head_stride: %" PRId64 "\n", params.k_head_stride); + printf("v_head_stride: %" PRId64 "\n", params.v_head_stride); + printf("h: %d\n", params.h); + printf("h_k: %d\n", params.h_k); + printf("h_h_k_ratio: %d\n", params.h_h_k_ratio); +} + +__forceinline__ __device__ +void print_flash_fwd_params(const Flash_fwd_params& params) { + print_qkv_params(params); + // LLM generated + printf("struct Flash_fwd_params:\n"); + printf("o_ptr: %p\n", params.o_ptr); + printf("oaccum_ptr: %p\n", params.oaccum_ptr); + printf("o_batch_stride: %ld\n", params.o_batch_stride); + printf("o_row_stride: %ld\n", params.o_row_stride); + printf("o_head_stride: %ld\n", params.o_head_stride); + printf("p_ptr: %p\n", params.p_ptr); + printf("softmax_lse_ptr: %p\n", params.softmax_lse_ptr); + printf("softmax_lseaccum_ptr: %p\n", params.softmax_lseaccum_ptr); + printf("b: %d\n", params.b); + printf("seqlen_q: %d\n", params.seqlen_q); + printf("seqlen_k: %d\n", params.seqlen_k); + printf("seqlen_knew: %d\n", params.seqlen_knew); + printf("d: %d\n", params.d); + printf("seqlen_q_rounded: %d\n", params.seqlen_q_rounded); + printf("seqlen_k_rounded: %d\n", params.seqlen_k_rounded); + printf("d_rounded: %d\n", params.d_rounded); + printf("rotary_dim: %d\n", params.rotary_dim); + printf("scale_softmax: %f\n", params.scale_softmax); + printf("scale_softmax_log2: %f\n", params.scale_softmax_log2); + printf("cu_seqlens_q: %p\n", params.cu_seqlens_q); + printf("cu_seqlens_k: %p\n", params.cu_seqlens_k); + printf("seqused_k: %p\n", params.seqused_k); + printf("blockmask: %p\n", params.blockmask); + printf("knew_ptr: %p\n", params.knew_ptr); + printf("vnew_ptr: %p\n", params.vnew_ptr); + printf("knew_batch_stride: %ld\n", params.knew_batch_stride); + printf("vnew_batch_stride: %ld\n", params.vnew_batch_stride); + printf("knew_row_stride: %ld\n", params.knew_row_stride); + printf("vnew_row_stride: %ld\n", params.vnew_row_stride); + printf("knew_head_stride: %ld\n", params.knew_head_stride); + printf("vnew_head_stride: %ld\n", params.vnew_head_stride); + printf("rotary_cos_ptr: %p\n", params.rotary_cos_ptr); + printf("rotary_sin_ptr: %p\n", params.rotary_sin_ptr); + printf("cache_batch_idx: %p\n", params.cache_batch_idx); + printf("block_table: %p\n", params.block_table); + printf("block_table_batch_stride: %ld\n", params.block_table_batch_stride); + printf("page_block_size: %d\n", params.page_block_size); + printf("p_dropout: %f\n", params.p_dropout); + printf("p_dropout_in_uint8_t: %u\n", params.p_dropout_in_uint8_t); + printf("rp_dropout: %f\n", params.rp_dropout); + printf("scale_softmax_rp_dropout: %f\n", params.scale_softmax_rp_dropout); + printf("window_size_left: %d\n", params.window_size_left); + printf("window_size_right: %d\n", params.window_size_right); + printf("philox_args: %p\n", &(params.philox_args)); + printf("rng_state: %p\n", params.rng_state); + printf("is_bf16: %d\n", params.is_bf16); + printf("is_causal: %d\n", params.is_causal); + printf("is_seqlens_k_cumulative: %d\n", params.is_seqlens_k_cumulative); + printf("is_rotary_interleaved: 
%d\n", params.is_rotary_interleaved); + printf("num_splits: %d\n", params.num_splits); + printf("alibi_slopes_ptr: %p\n", params.alibi_slopes_ptr); + printf("alibi_slopes_batch_stride: %ld\n", params.alibi_slopes_batch_stride); +} + template -__forceinline__ __device__ void -print_traits() { +__forceinline__ __device__ +void print_traits() { // bool printf("Kernel_traits::Share_Q_K_smem : %s\n", Kernel_traits::Share_Q_K_smem ? "true" : "false"); printf("Kernel_traits::Is_Q_in_regs : %s\n", Kernel_traits::Is_Q_in_regs ? "true" : "false"); @@ -36,7 +116,8 @@ print_traits() { printf("Kernel_traits::kSmemQSize : %d\n", Kernel_traits::kSmemQSize ); printf("Kernel_traits::kSmemKVSize : %d\n", Kernel_traits::kSmemKVSize ); printf("Kernel_traits::kSmemSize : %d\n", Kernel_traits::kSmemSize ); - printf("Kernel_traits::kGmemRowsPerThread: %d\n", Kernel_traits::kGmemRowsPerThread ); + printf("Kernel_traits::kGmemRowsPerThread: %d\n", Kernel_traits::kGmemRowsPerThread); + printf("Kernel_traits::kGmemThreadsPerRow: %d\n", Kernel_traits::kGmemThreadsPerRow); printf("Kernel_traits::kGmemElemsPerLoad : %d\n", Kernel_traits::kGmemElemsPerLoad ); // cute object diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index f2b28186c..204d8989c 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -43,9 +43,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kNWarps = Kernel_traits::kNWarps; -#if 1 - KIN_PRINT(print_traits()); -#endif auto seed_offset = at::cuda::philox::unpack(params.philox_args); flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t, @@ -60,9 +57,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q) return; -#if 1 - KIN_PRINT(print_binfo(binfo)) -#endif const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); @@ -144,19 +138,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 
0 : size(sQ)), typename Kernel_traits::SmemLayoutKV{}); -#if 1 - KIN_PRINT(print(sK.layout())) - KIN_PRINT(print(gK.layout())) -#endif Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); -#if 1 - KIN_PRINT(print(sV.layout())) - KIN_PRINT(print(sVt.layout())) - KIN_PRINT(print(sVtNoSwizzle.layout())) -#endif typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); @@ -167,27 +152,16 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); -#if 1 - KIN_PRINT(print(tKgK.layout())) - KIN_PRINT(print(tKsK.layout())) -#endif typename Kernel_traits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) -#if 1 - KIN_PRINT(print(tSrQ.layout())) - KIN_PRINT(print(tSrK.layout())) -#endif Tensor tSgS = thr_mma.partition_C(gP); Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K -#if 1 - KIN_PRINT(print(acc_o.layout())) -#endif // // Copy Atom retiling @@ -195,22 +169,13 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); -#if 0 - KIN_PRINT(smem_thr_copy_Q.print_all()); -#endif // if (cute::thread0()) {smem_thr_copy_Q.print_all();} Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); -#if 1 - KIN_PRINT(print(tSsQ.layout())) -#endif // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSsK = smem_thr_copy_K.partition_S(sK); -#if 1 - KIN_PRINT(print(tSsK.layout())) -#endif auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); @@ -227,10 +192,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Construct identity layout for sQ and sK Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) -#if 1 - KIN_PRINT(print(cQ.layout())) - KIN_PRINT(print(cKV.layout())) -#endif // Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K) // if (cute::thread0()) { // print(tScQ.layout()); printf("\n"); @@ -251,12 +212,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); -#if 1 - KIN_PRINT(print(tQcQ.layout())) - KIN_PRINT(print(tKVcKV.layout())) - KIN_PRINT(print(tQpQ.layout())) - KIN_PRINT(print(tKVpKV.layout())) -#endif // Set predicates for k 
bounds if (!Is_even_K) { @@ -538,13 +493,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons constexpr int kNWarps = Kernel_traits::kNWarps; #if 1 KIN_PRINT(print_traits()) - KIN_PRINT_BOOL(Is_causal) - KIN_PRINT_BOOL(Is_local) - KIN_PRINT_BOOL(Has_alibi) - KIN_PRINT_BOOL(Is_even_MN) - KIN_PRINT_BOOL(Is_even_K) - KIN_PRINT_BOOL(Split) - KIN_PRINT_BOOL(Append_KV) + KIN_PRINT(print_flash_fwd_params(params)) #endif using GmemTiledCopyO = std::conditional_t< @@ -558,9 +507,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); } // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); } if (m_block * kBlockM >= binfo.actual_seqlen_q) return; -#if 1 - KIN_PRINT(print_binfo(binfo)) -#endif const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; const int n_block_min = !Is_local @@ -625,17 +571,24 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride : (bidh / params.h_h_k_ratio) * params.k_head_stride; // block addresses are later resolved per-thread + + const index_t row_offset_k__shadow = block_table[(n_block_max - 1) * kBlockN / params.page_block_size] * params.k_batch_stride + (((n_block_max - 1) * kBlockN) % params.page_block_size) * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; const index_t row_offset_v = block_table == nullptr ? 
binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride : (bidh / params.h_h_k_ratio) * params.v_head_stride; + + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, make_stride(params.q_row_stride, _1{})); Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{})); + Tensor gK__shadow = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k__shadow), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); } Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), Shape, Int>{}, @@ -646,13 +599,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); -#if 1 - KIN_PRINT(print(sK.layout())) - KIN_PRINT(print(gK.layout())) - KIN_PRINT(print(sV.layout())) - KIN_PRINT(print(sVt.layout())) - KIN_PRINT(print(sVtNoSwizzle.layout())) -#endif typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_Q; auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); @@ -662,27 +608,31 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); Tensor tKgK = gmem_thr_copy_KV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKgK__shadow = gmem_thr_copy_KV.partition_S(gK__shadow); // (KCPY, KCPY_N, KCPY_K) Tensor tKsK = gmem_thr_copy_KV.partition_D(sK); Tensor tVgV = gmem_thr_copy_KV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVsV = gmem_thr_copy_KV.partition_D(sV); if (block_table != nullptr) { - tKgK.data() = gV.data() + flash::init_thread_kv_page_slice_offset(tidx, n_block_max, params.page_block_size, + tKgK.data() = gK.data() + flash::init_thread_kv_page_slice_offset(tidx, n_block_max, params.page_block_size, block_table, params.k_batch_stride, params.k_row_stride); - tVgV.data() = gV.data() + flash::init_thread_kv_page_slice_offset(tidx, n_block_max, params.page_block_size, + tVgV.data() = gV.data() + flash::init_thread_kv_page_slice_offset(tidx, n_block_max, params.page_block_size, block_table, params.v_batch_stride, params.v_row_stride); } - #if 1 - KIN_PRINT(print(tKgK.layout())) - KIN_PRINT(print(tKsK.layout())) -#endif - -#if 1 - fill(tVgV, 1.f * ((Element) tidx)); - __syncthreads(); - - KIN_PRINT(print_tensor(gV)) + KIN_PRINT([&]() { + for (int i = 0; i < n_block_max; i++) { + printf("%d ", block_table[i]); + } + }()) + // if (tidx == 8) fill(tKgK, 1.f * tidx); + // if (thread0()) { + // gK.data() = tKgK.data(); + // } + KIN_PRINT(print_tensor(tKgK)) + KIN_PRINT(print_tensor(gK)) + KIN_PRINT(print_tensor(tKgK__shadow)) + KIN_PRINT(print_tensor(gK__shadow)) #endif typename Kernel_traits::TiledMma tiled_mma; @@ -690,15 +640,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) Tensor tSrK = thr_mma.partition_fragment_B(sK); // 
(MMA,MMA_N,MMA_K) Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) -#if 1 - KIN_PRINT(print(tSrQ.layout())) - KIN_PRINT(print(tSrK.layout())) -#endif Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K -#if 1 - KIN_PRINT(print(acc_o.layout())) -#endif // // Copy Atom retiling @@ -707,16 +650,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); -#if 1 - KIN_PRINT(print(tSsQ.layout())) -#endif auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSsK = smem_thr_copy_K.partition_S(sK); -#if 1 - KIN_PRINT(print(tSsK.layout())) -#endif auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); @@ -732,10 +669,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Construct identity layout for sQ and sK Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) -#if 1 - KIN_PRINT(print(cQ.layout())) - KIN_PRINT(print(cKV.layout())) -#endif // Repeat the partitioning with identity layouts Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) @@ -744,12 +677,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); -#if 1 - KIN_PRINT(print(tQcQ.layout())) - KIN_PRINT(print(tKVcKV.layout())) - KIN_PRINT(print(tQpQ.layout())) - KIN_PRINT(print(tKVpKV.layout())) -#endif // Set predicates for k bounds if (!Is_even_K) { diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h index 48d460e2b..f21e98d35 100644 --- a/csrc/flash_attn/src/utils.h +++ b/csrc/flash_attn/src/utils.h @@ -4,6 +4,8 @@ #pragma once +#include "debug.h" + #include #include #include @@ -298,16 +300,17 @@ template __forceinline__ __device__ int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size, const int* block_table, const int page_stride, const int row_stride) { - // base col of thread's slice relative to the block - const int col_offset = tidx % Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemElemsPerLoad; - // base row of thread's slice relative to the block - const int block_row_offset = tidx / Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemRowsPerThread; - // base col of thread's slice relative to the entire tensor - const int global_row_offset = block_row_offset + (n_block_max - 1) * Kernel_traits::kBlockN; - // base row of thread's slice relative to the page + constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow; + constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread; + constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad; + constexpr int kBlockN = Kernel_traits::kBlockN; + + const int col_offset = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad; + const int block_row_offset = tidx / kGmemThreadsPerRow * 
kGmemRowsPerThread; + const int global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN; const int page_offset = global_row_offset % page_block_size; - const int virtual_page_idx = global_row_offset / page_block_size; + KIN_PRINT(printf("%d", virtual_page_idx)) return block_table[virtual_page_idx] * page_stride + page_offset * row_stride From 3691677702eed981256eb2dd4ac4b972f345076d Mon Sep 17 00:00:00 2001 From: skrider Date: Sun, 11 Feb 2024 08:30:20 +0000 Subject: [PATCH 08/19] paged copy refactor working for page size 256 --- csrc/flash_attn/src/flash_fwd_kernel.h | 44 ++++++++++++++------------ csrc/flash_attn/src/utils.h | 13 +++++--- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 204d8989c..6304ccfda 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -620,19 +620,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons block_table, params.v_batch_stride, params.v_row_stride); } #if 1 - KIN_PRINT([&]() { - for (int i = 0; i < n_block_max; i++) { - printf("%d ", block_table[i]); - } - }()) + // KIN_PRINT([&]() { + // for (int i = 0; i < n_block_max; i++) { + // printf("%d ", block_table[i]); + // } + // }()) // if (tidx == 8) fill(tKgK, 1.f * tidx); // if (thread0()) { // gK.data() = tKgK.data(); // } - KIN_PRINT(print_tensor(tKgK)) - KIN_PRINT(print_tensor(gK)) - KIN_PRINT(print_tensor(tKgK__shadow)) - KIN_PRINT(print_tensor(gK__shadow)) + // KIN_PRINT(print_tensor(tKgK)) + // KIN_PRINT(print_tensor(gK)) + // KIN_PRINT(print_tensor(tKgK__shadow)) + // KIN_PRINT(print_tensor(gK__shadow)) #endif typename Kernel_traits::TiledMma tiled_mma; @@ -783,10 +783,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // const int offset_diff = block_table_offset_next - block_table_offset_cur; // tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride; // tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride; - tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, block_table, - params.v_batch_stride, params.v_row_stride); - tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, block_table, - params.k_batch_stride, params.k_row_stride); + tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, + block_table, params.v_batch_stride, params.v_row_stride); + tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, + block_table, params.k_batch_stride, params.k_row_stride); } } } @@ -875,11 +875,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons if (block_table == nullptr) { tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); } else { - const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; - const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = n_block * kBlockN / params.page_block_size; - const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; - tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + 
(block_table_offset_next - block_table_offset_cur) * params.v_row_stride; + // const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; + // const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; + // const int block_table_idx_next = n_block * kBlockN / params.page_block_size; + // const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; + // tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; + tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset(tidx, n_block + 1, params.page_block_size, + block_table, params.v_batch_stride, params.v_row_stride); } flash::copy(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV); } else { @@ -910,8 +912,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons if (block_table == nullptr) { tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); } else { - tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, block_table, - params.k_batch_stride, params.k_row_stride); + tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, + block_table, params.k_batch_stride, params.k_row_stride); } flash::copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization @@ -950,7 +952,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons if (block_table == nullptr) { tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); } else { - tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, + tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset(tidx, n_block + 1, params.page_block_size, block_table, params.v_batch_stride, params.v_row_stride); } flash::copy(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV); diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h index f21e98d35..51032b6e4 100644 --- a/csrc/flash_attn/src/utils.h +++ b/csrc/flash_attn/src/utils.h @@ -310,7 +310,6 @@ int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, cons const int global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN; const int page_offset = global_row_offset % page_block_size; const int virtual_page_idx = global_row_offset / page_block_size; - KIN_PRINT(printf("%d", virtual_page_idx)) return block_table[virtual_page_idx] * page_stride + page_offset * row_stride @@ -324,12 +323,16 @@ template __forceinline__ __device__ int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const int page_block_size, const int* block_table, const int page_stride, const int row_stride) { - return 0; + constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow; + constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread; + constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad; + constexpr int kBlockN = Kernel_traits::kBlockN; + // base row of thread's slice relative to the block - const int block_row_offset = tidx / Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemRowsPerThread; + const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread; // base col of thread's slice relative to the 
entire tensor - const int global_row_offset_cur = block_row_offset + n_block * Kernel_traits::kBlockN; - const int global_row_offset_next = block_row_offset + (n_block - 1) * Kernel_traits::kBlockN; + const int global_row_offset_cur = block_row_offset + n_block * kBlockN; + const int global_row_offset_next = block_row_offset + (n_block - 1) * kBlockN; // base row of thread's slice relative to the page const int page_offset_cur = global_row_offset_cur % page_block_size; const int page_offset_next = global_row_offset_next % page_block_size; From c05b8570cb7524add6f2724a18a0b116c694b229 Mon Sep 17 00:00:00 2001 From: skrider Date: Sun, 11 Feb 2024 08:45:15 +0000 Subject: [PATCH 09/19] allow small page sizes in flash api --- csrc/flash_attn/flash_api.cpp | 2 +- csrc/flash_attn/src/utils.h | 5 ++--- tests/test_flash_attn.py | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 001acacaf..a22752551 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -1285,7 +1285,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); const int num_blocks = !paged_KV ? 0 : kcache.size(0); const int page_block_size = !paged_KV ? 1 : kcache.size(1); - TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); + TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16"); const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size; const int num_heads_k = kcache.size(2); const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size; diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h index 51032b6e4..b7f8059f8 100644 --- a/csrc/flash_attn/src/utils.h +++ b/csrc/flash_attn/src/utils.h @@ -328,12 +328,11 @@ int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad; constexpr int kBlockN = Kernel_traits::kBlockN; - // base row of thread's slice relative to the block const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread; - // base col of thread's slice relative to the entire tensor + const int global_row_offset_cur = block_row_offset + n_block * kBlockN; const int global_row_offset_next = block_row_offset + (n_block - 1) * kBlockN; - // base row of thread's slice relative to the page + const int page_offset_cur = global_row_offset_cur % page_block_size; const int page_offset_next = global_row_offset_next % page_block_size; diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 05f94490b..02a24cfca 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -1833,7 +1833,7 @@ def test_flash_attn_splitkv( @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0]) # @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512]) -@pytest.mark.parametrize("paged_kv_block_size", [256]) +@pytest.mark.parametrize("paged_kv_block_size", [16, 256, 512]) # @pytest.mark.parametrize("has_batch_idx", [False, True]) @pytest.mark.parametrize("has_batch_idx", [False]) @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) From 347a6253060957ac87b9bf4df1c4e677e18c8cfa Mon Sep 17 00:00:00 2001 From: skrider Date: Sun, 11 Feb 2024 10:05:06 +0000 Subject: [PATCH 10/19] remove print statements --- 
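For reference: with the print scaffolding about to go away, the paged path
rests entirely on the two offset helpers in utils.h. Their arithmetic can be
sanity-checked on the host with a standalone sketch. The trait values below
are stand-ins for illustration, not the real Kernel_traits constants, and
slice_offset/advance_delta merely mirror init_thread_kv_page_slice_offset and
advance_thread_kv_page_slice_offset rather than being the device code itself:

#include <cassert>
#include <cstdio>

// Stand-in values for what a Kernel_traits instantiation would provide.
constexpr int kGmemThreadsPerRow = 8;   // threads covering one row of loads
constexpr int kGmemRowsPerThread = 8;   // consecutive rows owned by a thread
constexpr int kGmemElemsPerLoad  = 8;   // elements per vectorized load
constexpr int kBlockN            = 64;  // KV tile height in rows

// Absolute gmem offset of thread tidx's slice for KV block n_block, resolved
// through the block table: virtual page -> physical page -> linear offset.
int slice_offset(int tidx, int n_block, int page_block_size,
                 const int* block_table, int page_stride, int row_stride) {
    const int col_offset = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad;
    const int row = tidx / kGmemThreadsPerRow * kGmemRowsPerThread
                    + n_block * kBlockN;
    return block_table[row / page_block_size] * page_stride
           + (row % page_block_size) * row_stride
           + col_offset;
}

// Delta that moves a thread's slice from block n_block to block n_block - 1,
// mirroring advance_thread_kv_page_slice_offset (col_offset cancels out).
int advance_delta(int tidx, int n_block, int page_block_size,
                  const int* block_table, int page_stride, int row_stride) {
    const int row_cur = tidx / kGmemThreadsPerRow * kGmemRowsPerThread
                        + n_block * kBlockN;
    const int row_next = row_cur - kBlockN;
    return (block_table[row_next / page_block_size]
            - block_table[row_cur / page_block_size]) * page_stride
           + (row_next % page_block_size - row_cur % page_block_size) * row_stride;
}

int main() {
    const int block_table[] = {5, 2, 7, 0, 3, 6, 1, 4};  // arbitrary physical pages
    const int page_block_size = 64;
    const int row_stride = 128;
    const int page_stride = page_block_size * row_stride;
    // Invariant the kernel relies on: start from the offset of the last block
    // and step backwards one block at a time; every step must land on the
    // absolute offset of the previous block.
    for (int tidx = 0; tidx < 64; ++tidx) {
        int off = slice_offset(tidx, 7, page_block_size, block_table, page_stride, row_stride);
        for (int n_block = 7; n_block > 0; --n_block) {
            off += advance_delta(tidx, n_block, page_block_size, block_table, page_stride, row_stride);
            assert(off == slice_offset(tidx, n_block - 1, page_block_size, block_table, page_stride, row_stride));
        }
    }
    printf("paged-KV offset walk is consistent\n");
    return 0;
}

Walking the offsets this way is exactly what the kernel does: resolve the
slice for the last KV block once, then move backwards by adding the advance
delta, with the block table supplying the page indirection at each hop.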
csrc/flash_attn/src/debug.h | 141 ------------------------- csrc/flash_attn/src/flash_fwd_kernel.h | 21 ---- csrc/flash_attn/src/utils.h | 2 - 3 files changed, 164 deletions(-) delete mode 100644 csrc/flash_attn/src/debug.h diff --git a/csrc/flash_attn/src/debug.h b/csrc/flash_attn/src/debug.h deleted file mode 100644 index ed7f011b4..000000000 --- a/csrc/flash_attn/src/debug.h +++ /dev/null @@ -1,141 +0,0 @@ -#include -#include "block_info.h" - -#pragma once - -#define KIN_PRINT(statement) \ - if (thread0()) { \ - printf("\n[kin:start:%s]\n", #statement); \ - statement; \ - printf("\n[kin:end:%s]\n", #statement); \ - } - -#define KIN_PRINT_BOOL(BOOL) \ - if (thread0()) { \ - printf("\n[kin:start:%s]\n", #BOOL); \ - printf("%s", BOOL ? "true" : "false"); \ - printf("\n[kin:end:%s]\n", #BOOL); \ - } - -__forceinline__ __device__ -void print_qkv_params(const Qkv_params& params) { - // LLM generated - printf("Qkv_params:\n"); - printf("q_ptr: %p\n", params.q_ptr); - printf("k_ptr: %p\n", params.k_ptr); - printf("v_ptr: %p\n", params.v_ptr); - printf("q_batch_stride: %" PRId64 "\n", params.q_batch_stride); - printf("k_batch_stride: %" PRId64 "\n", params.k_batch_stride); - printf("v_batch_stride: %" PRId64 "\n", params.v_batch_stride); - printf("q_row_stride: %" PRId64 "\n", params.q_row_stride); - printf("k_row_stride: %" PRId64 "\n", params.k_row_stride); - printf("v_row_stride: %" PRId64 "\n", params.v_row_stride); - printf("q_head_stride: %" PRId64 "\n", params.q_head_stride); - printf("k_head_stride: %" PRId64 "\n", params.k_head_stride); - printf("v_head_stride: %" PRId64 "\n", params.v_head_stride); - printf("h: %d\n", params.h); - printf("h_k: %d\n", params.h_k); - printf("h_h_k_ratio: %d\n", params.h_h_k_ratio); -} - -__forceinline__ __device__ -void print_flash_fwd_params(const Flash_fwd_params& params) { - print_qkv_params(params); - // LLM generated - printf("struct Flash_fwd_params:\n"); - printf("o_ptr: %p\n", params.o_ptr); - printf("oaccum_ptr: %p\n", params.oaccum_ptr); - printf("o_batch_stride: %ld\n", params.o_batch_stride); - printf("o_row_stride: %ld\n", params.o_row_stride); - printf("o_head_stride: %ld\n", params.o_head_stride); - printf("p_ptr: %p\n", params.p_ptr); - printf("softmax_lse_ptr: %p\n", params.softmax_lse_ptr); - printf("softmax_lseaccum_ptr: %p\n", params.softmax_lseaccum_ptr); - printf("b: %d\n", params.b); - printf("seqlen_q: %d\n", params.seqlen_q); - printf("seqlen_k: %d\n", params.seqlen_k); - printf("seqlen_knew: %d\n", params.seqlen_knew); - printf("d: %d\n", params.d); - printf("seqlen_q_rounded: %d\n", params.seqlen_q_rounded); - printf("seqlen_k_rounded: %d\n", params.seqlen_k_rounded); - printf("d_rounded: %d\n", params.d_rounded); - printf("rotary_dim: %d\n", params.rotary_dim); - printf("scale_softmax: %f\n", params.scale_softmax); - printf("scale_softmax_log2: %f\n", params.scale_softmax_log2); - printf("cu_seqlens_q: %p\n", params.cu_seqlens_q); - printf("cu_seqlens_k: %p\n", params.cu_seqlens_k); - printf("seqused_k: %p\n", params.seqused_k); - printf("blockmask: %p\n", params.blockmask); - printf("knew_ptr: %p\n", params.knew_ptr); - printf("vnew_ptr: %p\n", params.vnew_ptr); - printf("knew_batch_stride: %ld\n", params.knew_batch_stride); - printf("vnew_batch_stride: %ld\n", params.vnew_batch_stride); - printf("knew_row_stride: %ld\n", params.knew_row_stride); - printf("vnew_row_stride: %ld\n", params.vnew_row_stride); - printf("knew_head_stride: %ld\n", params.knew_head_stride); - printf("vnew_head_stride: %ld\n", 
params.vnew_head_stride); - printf("rotary_cos_ptr: %p\n", params.rotary_cos_ptr); - printf("rotary_sin_ptr: %p\n", params.rotary_sin_ptr); - printf("cache_batch_idx: %p\n", params.cache_batch_idx); - printf("block_table: %p\n", params.block_table); - printf("block_table_batch_stride: %ld\n", params.block_table_batch_stride); - printf("page_block_size: %d\n", params.page_block_size); - printf("p_dropout: %f\n", params.p_dropout); - printf("p_dropout_in_uint8_t: %u\n", params.p_dropout_in_uint8_t); - printf("rp_dropout: %f\n", params.rp_dropout); - printf("scale_softmax_rp_dropout: %f\n", params.scale_softmax_rp_dropout); - printf("window_size_left: %d\n", params.window_size_left); - printf("window_size_right: %d\n", params.window_size_right); - printf("philox_args: %p\n", &(params.philox_args)); - printf("rng_state: %p\n", params.rng_state); - printf("is_bf16: %d\n", params.is_bf16); - printf("is_causal: %d\n", params.is_causal); - printf("is_seqlens_k_cumulative: %d\n", params.is_seqlens_k_cumulative); - printf("is_rotary_interleaved: %d\n", params.is_rotary_interleaved); - printf("num_splits: %d\n", params.num_splits); - printf("alibi_slopes_ptr: %p\n", params.alibi_slopes_ptr); - printf("alibi_slopes_batch_stride: %ld\n", params.alibi_slopes_batch_stride); -} - -template -__forceinline__ __device__ -void print_traits() { - // bool - printf("Kernel_traits::Share_Q_K_smem : %s\n", Kernel_traits::Share_Q_K_smem ? "true" : "false"); - printf("Kernel_traits::Is_Q_in_regs : %s\n", Kernel_traits::Is_Q_in_regs ? "true" : "false"); - - // int - printf("Kernel_traits::kNWarps : %d\n", Kernel_traits::kNWarps ); - printf("Kernel_traits::kNThreads : %d\n", Kernel_traits::kNThreads ); - printf("Kernel_traits::kBlockM : %d\n", Kernel_traits::kBlockM ); - printf("Kernel_traits::kBlockN : %d\n", Kernel_traits::kBlockN ); - printf("Kernel_traits::kHeadDim : %d\n", Kernel_traits::kHeadDim ); - printf("Kernel_traits::kBlockKSmem : %d\n", Kernel_traits::kBlockKSmem ); - printf("Kernel_traits::kBlockKGmem : %d\n", Kernel_traits::kBlockKGmem ); - printf("Kernel_traits::kSwizzle : %d\n", Kernel_traits::kSwizzle ); - printf("Kernel_traits::kSmemQSize : %d\n", Kernel_traits::kSmemQSize ); - printf("Kernel_traits::kSmemKVSize : %d\n", Kernel_traits::kSmemKVSize ); - printf("Kernel_traits::kSmemSize : %d\n", Kernel_traits::kSmemSize ); - printf("Kernel_traits::kGmemRowsPerThread: %d\n", Kernel_traits::kGmemRowsPerThread); - printf("Kernel_traits::kGmemThreadsPerRow: %d\n", Kernel_traits::kGmemThreadsPerRow); - printf("Kernel_traits::kGmemElemsPerLoad : %d\n", Kernel_traits::kGmemElemsPerLoad ); - - // cute object - printf("Kernel_traits::GmemLayoutAtom : "); - cute::print(Kernel_traits::GmemLayoutAtom()); - printf("\n"); - printf("Kernel_traits::GmemTiledCopyQKV :\n"); - cute::print(Kernel_traits::GmemTiledCopyQKV()); - printf("\n"); - -} - -template -__forceinline__ __device__ void -print_binfo(const BlockInfo& binfo) { - printf("binfo.sum_s_q : %d\n", binfo.sum_s_q); - printf("binfo.sum_s_k : %d\n", binfo.sum_s_k); - printf("binfo.actual_seqlen_q : %d\n", binfo.actual_seqlen_q); - printf("binfo.seqlen_k_cache : %d\n", binfo.seqlen_k_cache); - printf("binfo.actual_seqlen_k : %d\n", binfo.actual_seqlen_k); -} diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 6304ccfda..f638a4956 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -18,8 +18,6 @@ #include "dropout.h" #include "rotary.h" -#include "debug.h" - namespace 
flash { using namespace cute; @@ -491,10 +489,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kNWarps = Kernel_traits::kNWarps; -#if 1 - KIN_PRINT(print_traits()) - KIN_PRINT(print_flash_fwd_params(params)) -#endif using GmemTiledCopyO = std::conditional_t< !Split, @@ -619,21 +613,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tVgV.data() = gV.data() + flash::init_thread_kv_page_slice_offset(tidx, n_block_max, params.page_block_size, block_table, params.v_batch_stride, params.v_row_stride); } -#if 1 - // KIN_PRINT([&]() { - // for (int i = 0; i < n_block_max; i++) { - // printf("%d ", block_table[i]); - // } - // }()) - // if (tidx == 8) fill(tKgK, 1.f * tidx); - // if (thread0()) { - // gK.data() = tKgK.data(); - // } - // KIN_PRINT(print_tensor(tKgK)) - // KIN_PRINT(print_tensor(gK)) - // KIN_PRINT(print_tensor(tKgK__shadow)) - // KIN_PRINT(print_tensor(gK__shadow)) -#endif typename Kernel_traits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h index b7f8059f8..a838dad37 100644 --- a/csrc/flash_attn/src/utils.h +++ b/csrc/flash_attn/src/utils.h @@ -4,8 +4,6 @@ #pragma once -#include "debug.h" - #include #include #include From 3bb71a960b67bdbb1cbd23a6efcd149911ab2557 Mon Sep 17 00:00:00 2001 From: skrider Date: Sun, 11 Feb 2024 10:08:19 +0000 Subject: [PATCH 11/19] tidy flash_fwd_kernel --- csrc/flash_attn/src/flash_fwd_kernel.h | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index f638a4956..58c51f360 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -566,7 +566,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride : (bidh / params.h_h_k_ratio) * params.k_head_stride; // block addresses are later resolved per-thread - const index_t row_offset_k__shadow = block_table[(n_block_max - 1) * kBlockN / params.page_block_size] * params.k_batch_stride + (((n_block_max - 1) * kBlockN) % params.page_block_size) * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; const index_t row_offset_v = block_table == nullptr ? 
binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride @@ -580,9 +579,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{})); - Tensor gK__shadow = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k__shadow), - Shape, Int>{}, - make_stride(params.k_row_stride, _1{})); // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); } Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), Shape, Int>{}, @@ -602,7 +598,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); Tensor tKgK = gmem_thr_copy_KV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) - Tensor tKgK__shadow = gmem_thr_copy_KV.partition_S(gK__shadow); // (KCPY, KCPY_N, KCPY_K) Tensor tKsK = gmem_thr_copy_KV.partition_D(sK); Tensor tVgV = gmem_thr_copy_KV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVsV = gmem_thr_copy_KV.partition_D(sV); @@ -754,14 +749,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); } else { if (n_block > n_block_copy_min) { - // const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; - // const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; - // const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; - // const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; - // const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur]; - // const int offset_diff = block_table_offset_next - block_table_offset_cur; - // tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride; - // tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride; tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, block_table, params.v_batch_stride, params.v_row_stride); tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, @@ -854,11 +841,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons if (block_table == nullptr) { tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); } else { - // const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; - // const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; - // const int block_table_idx_next = n_block * kBlockN / params.page_block_size; - // const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; - // tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset(tidx, n_block + 1, params.page_block_size, block_table, 
params.v_batch_stride, params.v_row_stride); } From fa13c6b06ec9d5abfe96ccc40a70f3051f7fb597 Mon Sep 17 00:00:00 2001 From: skrider Date: Tue, 13 Feb 2024 02:05:11 +0000 Subject: [PATCH 12/19] compiles for all h but 128 --- csrc/flash_attn/src/flash_fwd_kernel.h | 22 ++++++++++++++++------ csrc/flash_attn/src/utils.h | 9 ++++++++- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 58c51f360..44b9602d6 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -597,10 +597,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); - Tensor tKgK = gmem_thr_copy_KV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) - Tensor tKsK = gmem_thr_copy_KV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_KV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) - Tensor tVsV = gmem_thr_copy_KV.partition_D(sV); + Tensor tKgK_ = gmem_thr_copy_KV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK_ = gmem_thr_copy_KV.partition_D(sK); + Tensor tVgV_ = gmem_thr_copy_KV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV_ = gmem_thr_copy_KV.partition_D(sV); + + Tensor tKgK = make_tensor(tKgK_.data(), unsqueeze<2>(layout<0>(tKgK_.layout()))); + Tensor tKsK = make_tensor(tKsK_.data(), unsqueeze<2>(layout<0>(tKsK_.layout()))); + Tensor tVgV = make_tensor(tVgV_.data(), unsqueeze<2>(layout<0>(tVgV_.layout()))); + Tensor tVsV = make_tensor(tVsV_.data(), unsqueeze<2>(layout<0>(tVsV_.layout()))); if (block_table != nullptr) { tKgK.data() = gK.data() + flash::init_thread_kv_page_slice_offset(tidx, n_block_max, params.page_block_size, @@ -708,8 +713,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), Shape, Int>{}, make_stride(params.vnew_row_stride, _1{})); - Tensor tKgKnew = gmem_thr_copy_KV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) - Tensor tVgVnew = gmem_thr_copy_KV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + typename Kernel_traits::GmemTiledCopyQKVPaged gmem_tiled_copy_KV_new; + auto gmem_thr_copy_KV_new = gmem_tiled_copy_KV_new.get_thread_slice(tidx); + Tensor tKgKnew_ = gmem_thr_copy_KV_new.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tVgVnew_ = gmem_thr_copy_KV_new.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + + auto tKgKnew = make_tensor(tKgKnew_.data(), unsqueeze<2>(layout<0>(tKgKnew_.layout()))); + auto tVgVnew = make_tensor(tVgVnew_.data(), unsqueeze<2>(layout<0>(tVgVnew_.layout()))); const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); auto tKgK_data = tKgK.data(); diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h index a838dad37..fc87f6f01 100644 --- a/csrc/flash_attn/src/utils.h +++ b/csrc/flash_attn/src/utils.h @@ -323,7 +323,6 @@ int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const const int* block_table, const int page_stride, const int row_stride) { constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow; constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread; - constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad; constexpr int kBlockN = Kernel_traits::kBlockN; const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread; @@ -345,6 +344,14 @@ int advance_thread_kv_page_slice_offset(const int tidx, const int 
n_block, const //////////////////////////////////////////////////////////////////////////////////////////////////// +template +__forceinline__ __device__ constexpr auto unsqueeze(Layout l) { + return make_layout(insert(l.shape(), Int<1>{}), + insert(l.stride(), Int<0>{})); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template From bde5aec8a1e526bcfc83b5daa0453dc38fdb1b6e Mon Sep 17 00:00:00 2001 From: skrider Date: Tue, 13 Feb 2024 06:40:32 +0000 Subject: [PATCH 13/19] all working except rotary embedding --- csrc/flash_attn/src/flash_fwd_kernel.h | 13 +++++++------ csrc/flash_attn/src/utils.h | 12 ++++++++---- tests/test_flash_attn.py | 25 ++++++++----------------- 3 files changed, 23 insertions(+), 27 deletions(-) diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 44b9602d6..4bb9d6bab 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -597,15 +597,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); + Tensor tKgK_ = gmem_thr_copy_KV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) Tensor tKsK_ = gmem_thr_copy_KV.partition_D(sK); Tensor tVgV_ = gmem_thr_copy_KV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVsV_ = gmem_thr_copy_KV.partition_D(sV); - Tensor tKgK = make_tensor(tKgK_.data(), unsqueeze<2>(layout<0>(tKgK_.layout()))); - Tensor tKsK = make_tensor(tKsK_.data(), unsqueeze<2>(layout<0>(tKsK_.layout()))); - Tensor tVgV = make_tensor(tVgV_.data(), unsqueeze<2>(layout<0>(tVgV_.layout()))); - Tensor tVsV = make_tensor(tVsV_.data(), unsqueeze<2>(layout<0>(tVsV_.layout()))); + Tensor tKgK = make_tensor(tKgK_.data(), reshape_thread_tile(tKgK_.layout())); + Tensor tKsK = make_tensor(tKsK_.data(), reshape_thread_tile(tKsK_.layout())); + Tensor tVgV = make_tensor(tVgV_.data(), reshape_thread_tile(tVgV_.layout())); + Tensor tVsV = make_tensor(tVsV_.data(), reshape_thread_tile(tVsV_.layout())); if (block_table != nullptr) { tKgK.data() = gK.data() + flash::init_thread_kv_page_slice_offset(tidx, n_block_max, params.page_block_size, @@ -718,8 +719,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor tKgKnew_ = gmem_thr_copy_KV_new.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) Tensor tVgVnew_ = gmem_thr_copy_KV_new.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) - auto tKgKnew = make_tensor(tKgKnew_.data(), unsqueeze<2>(layout<0>(tKgKnew_.layout()))); - auto tVgVnew = make_tensor(tVgVnew_.data(), unsqueeze<2>(layout<0>(tVgVnew_.layout()))); + auto tKgKnew = make_tensor(tKgKnew_.data(), reshape_thread_tile(tKgKnew_.layout())); + auto tVgVnew = make_tensor(tVgVnew_.data(), reshape_thread_tile(tVgVnew_.layout())); const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); auto tKgK_data = tKgK.data(); diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h index fc87f6f01..52dc55141 100644 --- a/csrc/flash_attn/src/utils.h +++ b/csrc/flash_attn/src/utils.h @@ -344,10 +344,14 @@ int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const //////////////////////////////////////////////////////////////////////////////////////////////////// -template -__forceinline__ __device__ constexpr auto unsqueeze(Layout l) { - return make_layout(insert(l.shape(), Int<1>{}), - insert(l.stride(), Int<0>{})); +// somewhat unorthodox reshape 
function. Given a tuple ((v1, v2), m, k), returns (v1, v2, k), +// where v2 may be a tuple itself, in the case of swizzled smem-backed thread tiles. This ensures +// that paged and non-paged copies result in equivalently shaped, if not necessarily strided, tensors. +template +__forceinline__ __device__ +auto reshape_thread_tile(Layout l) { + return make_layout(append(get<0>(l.shape()), get<2>(l.shape())), + append(get<0>(l.stride()), get<2>(l.stride()))); } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 02a24cfca..ba8249f87 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -1818,22 +1818,22 @@ def test_flash_attn_splitkv( # @pytest.mark.parametrize("num_splits", [1]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) -@pytest.mark.parametrize("new_kv", [False, True]) +@pytest.mark.parametrize("new_kv", [False]) # @pytest.mark.parametrize("new_kv", [False]) -@pytest.mark.parametrize("alibi", [False, True]) +@pytest.mark.parametrize("alibi", [True]) # @pytest.mark.parametrize("alibi", [False]) @pytest.mark.parametrize("local", [False, True]) # @pytest.mark.parametrize("local", [False]) -@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("causal", [False]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) -@pytest.mark.parametrize("rotary_interleaved", [False, True]) +@pytest.mark.parametrize("rotary_interleaved", [False]) # @pytest.mark.parametrize("rotary_interleaved", [False]) -@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) +@pytest.mark.parametrize("rotary_fraction", [0.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0]) # @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512]) -@pytest.mark.parametrize("paged_kv_block_size", [16, 256, 512]) +@pytest.mark.parametrize("paged_kv_block_size", [16, 48, 256, 512]) # @pytest.mark.parametrize("has_batch_idx", [False, True]) @pytest.mark.parametrize("has_batch_idx", [False]) @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) @@ -1844,17 +1844,8 @@ def test_flash_attn_splitkv( @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ - (1, 128), - (1, 339), - (3, 1024), - (64, 800), - (64, 256), - (3, 799), - (64, 2048), - (16, 20000), - (1, 128 * 1024), - (16, 128 * 1024), - (128, 128), + (1, 10 * 1024), + (16, 10 * 1024), ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) From bc668587a54f07da0f3d8e20e73f7b284e17ed46 Mon Sep 17 00:00:00 2001 From: skrider Date: Tue, 13 Feb 2024 08:06:22 +0000 Subject: [PATCH 14/19] add page size 16 to tests --- tests/test_flash_attn.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index ba8249f87..bab67c401 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -1818,24 +1818,24 @@ def test_flash_attn_splitkv( # @pytest.mark.parametrize("num_splits", [1]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) -@pytest.mark.parametrize("new_kv", [False]) +@pytest.mark.parametrize("new_kv", [False, True]) # @pytest.mark.parametrize("new_kv", [False]) -@pytest.mark.parametrize("alibi", [True]) +@pytest.mark.parametrize("alibi", [False, True]) # 
@pytest.mark.parametrize("alibi", [False]) @pytest.mark.parametrize("local", [False, True]) # @pytest.mark.parametrize("local", [False]) -@pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [False]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) -@pytest.mark.parametrize("rotary_interleaved", [False]) +@pytest.mark.parametrize("rotary_interleaved", [False, True]) # @pytest.mark.parametrize("rotary_interleaved", [False]) -@pytest.mark.parametrize("rotary_fraction", [0.0]) +@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0]) # @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512]) -@pytest.mark.parametrize("paged_kv_block_size", [16, 48, 256, 512]) -# @pytest.mark.parametrize("has_batch_idx", [False, True]) -@pytest.mark.parametrize("has_batch_idx", [False]) +@pytest.mark.parametrize("paged_kv_block_size", [16, 256, 512]) +@pytest.mark.parametrize("has_batch_idx", [False, True]) +# @pytest.mark.parametrize("has_batch_idx", [False]) @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) @@ -1844,8 +1844,17 @@ def test_flash_attn_splitkv( @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ - (1, 10 * 1024), - (16, 10 * 1024), + (1, 128), + (1, 339), + (3, 1024), + (64, 800), + (64, 256), + (3, 799), + (64, 2048), + (16, 20000), + (1, 128 * 1024), + (16, 128 * 1024), + (128, 128), ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) From 0f5a45ea5303cae2f2acac1bf0201bb734561d0d Mon Sep 17 00:00:00 2001 From: skrider Date: Mon, 26 Feb 2024 07:27:54 +0000 Subject: [PATCH 15/19] reshape rotary sin/cos copy to align with paged KV copy --- csrc/flash_attn/src/flash_fwd_kernel.h | 31 ++++++++++++++++++-------- csrc/flash_attn/src/kernel_traits.h | 12 +++++++++- csrc/flash_attn/src/utils.h | 13 ++++++++++- tests/test_flash_attn.py | 2 +- 4 files changed, 46 insertions(+), 12 deletions(-) diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 4bb9d6bab..9ae7dd27e 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -652,7 +652,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Repeat the partitioning with identity layouts Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tKVcKV = gmem_thr_copy_KV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + Tensor tKVcKV_ = gmem_thr_copy_KV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + Tensor tKVcKV = make_tensor(tKVcKV_.data(), reshape_thread_tile(tKVcKV_.layout())); // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); @@ -669,11 +670,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Prologue // Copy from Knew to K, optionally apply rotary embedding. 
From 0f5a45ea5303cae2f2acac1bf0201bb734561d0d Mon Sep 17 00:00:00 2001
From: skrider
Date: Mon, 26 Feb 2024 07:27:54 +0000
Subject: [PATCH 15/19] reshape rotary sin/cos copy to align with paged KV copy

---
 csrc/flash_attn/src/flash_fwd_kernel.h | 31 ++++++++++++++++++--------
 csrc/flash_attn/src/kernel_traits.h    | 12 +++++++++-
 csrc/flash_attn/src/utils.h            | 13 ++++++++++-
 tests/test_flash_attn.py               |  2 +-
 4 files changed, 46 insertions(+), 12 deletions(-)

diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h
index 4bb9d6bab..9ae7dd27e 100644
--- a/csrc/flash_attn/src/flash_fwd_kernel.h
+++ b/csrc/flash_attn/src/flash_fwd_kernel.h
@@ -652,7 +652,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
     // Repeat the partitioning with identity layouts
     Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);       // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
-    Tensor tKVcKV = gmem_thr_copy_KV.partition_S(cKV);   // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
+    Tensor tKVcKV_ = gmem_thr_copy_KV.partition_S(cKV);  // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
+    Tensor tKVcKV = make_tensor(tKVcKV_.data(), reshape_thread_tile(tKVcKV_.layout()));
 
     // Allocate predicate tensors for k
     Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
@@ -669,11 +670,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // Prologue
 
     // Copy from Knew to K, optionally apply rotary embedding.
-    typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
-    auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
-    typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
-    auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
     if constexpr (Append_KV) {
+        typename Kernel_traits::GmemTiledCopyRotcossinPaged gmem_tiled_copy_rotary;
+        auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
+        typename Kernel_traits::GmemTiledCopyRotcossinContPaged gmem_tiled_copy_rotary_cont;
+        auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
+
         // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
         // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
         // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
@@ -690,10 +692,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
                                       Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                       make_stride(params.rotary_dim / 2, _1{}));
-        Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
-        Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
-        Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
-        Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
+
+        Tensor tRgCos_ = gmem_thr_copy_rotary.partition_S(gCos);
+        Tensor tRgSin_ = gmem_thr_copy_rotary.partition_S(gSin);
+        Tensor tRgCosCont_ = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
+        Tensor tRgSinCont_ = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
+
+        Tensor tRgCos = make_tensor(tRgCos_.data(), reshape_thread_tile(tRgCos_.layout()));
+        Tensor tRgSin = make_tensor(tRgSin_.data(), reshape_thread_tile(tRgSin_.layout()));
+        Tensor tRgCosCont = make_tensor(tRgCosCont_.data(), reshape_flatten_thread_tile(tRgCosCont_.layout()));
+        Tensor tRgSinCont = make_tensor(tRgSinCont_.data(), reshape_flatten_thread_tile(tRgSinCont_.layout()));
+
         // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); }
         // if (cute::thread(8, 0)) { print_tensor(gCos); }
         // if (cute::thread(0, 0)) { print_tensor(tRgCos); }
@@ -779,6 +788,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
                     binfo.actual_seqlen_q - m_block * kBlockM);
     } else {
+        typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
+        auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
+        typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
+        auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
         const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
         // If not causal, all the queries get the same cos/sin, taken at location seqlen_k_cache.
         // We do this by setting the row stride of gCos / gSin to 0.
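The kernel-side change above swaps in Paged copy atoms for the append path and moves the plain Rotcossin copies into the non-append branch. The reason the paged path needs its own tiling is geometric: with small pages, the rows one thread services inside a kBlockN tile can straddle a physical page boundary, so they cannot all be reached from one base pointer with a fixed row stride. A rough sketch of the arithmetic in Python, using assumed example values rather than the kernel's actual trait constants:

    # Assumed example values, not the real Kernel_traits constants.
    kBlockN = 128            # KV rows processed per block iteration
    kNThreads = 128          # threads in the CTA
    kGmemThreadsPerRow = 8   # threads cooperating on one tile row
    page_block_size = 16     # rows per physical KV page

    rows_per_pass = kNThreads // kGmemThreadsPerRow   # rows covered per copy pass: 16
    kGmemRowsPerThread = kBlockN // rows_per_pass     # rows owned by each thread: 8

    # A thread's 8 consecutive rows may cross a 16-row page boundary, so the sin/cos
    # tiles are partitioned with the same rows-per-thread grouping as the paged K/V
    # copy, letting reshape_thread_tile line the tiles up mode for mode.
    print(rows_per_pass, kGmemRowsPerThread)  # 16 8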
diff --git a/csrc/flash_attn/src/kernel_traits.h b/csrc/flash_attn/src/kernel_traits.h
index 5d6cab9d9..8556bdd46 100644
--- a/csrc/flash_attn/src/kernel_traits.h
+++ b/csrc/flash_attn/src/kernel_traits.h
@@ -158,7 +158,9 @@ struct Flash_fwd_kernel_traits : public Base {
         make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                         GmemLayoutAtomOaccum{},
                         Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
-    using GmemLayoutAtomRotcossin = GmemLayoutAtom;
+    // using GmemLayoutAtomRotcossin = GmemLayoutAtom;
+    using GmemLayoutAtomRotcossin = Layout<Shape<Int<...>, Int<...>>,
+                                           Stride<Int<...>, _1>>;
     using GmemTiledCopyRotcossin = decltype(
         make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
                         GmemLayoutAtomRotcossin{},
@@ -167,6 +169,14 @@ struct Flash_fwd_kernel_traits : public Base {
         make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                         GmemLayoutAtomRotcossin{},
                         Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per load
+    using GmemTiledCopyRotcossinPaged = decltype(
+        make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
+                        GmemLayoutAtomRotcossin{},
+                        Layout<Shape<Int<kGmemRowsPerThread>, _4>, Stride<_4, _1>>{}));  // Val layout, 4 vals per load
+    using GmemTiledCopyRotcossinContPaged = decltype(
+        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+                        GmemLayoutAtomRotcossin{},
+                        Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{}));  // Val layout, 8 vals per load
 };
 
 // Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.

diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h
index 52dc55141..46b2ea039 100644
--- a/csrc/flash_attn/src/utils.h
+++ b/csrc/flash_attn/src/utils.h
@@ -344,7 +344,7 @@ int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-// somewhat unorthodox reshape function. Given a tuple ((v1, v2), m, k), returns (v1, v2, k),
+// Layout reshape function. Given a layout with modes ((v1, v2), m, k), returns (v1, v2, k),
 // where v2 may be a tuple itself, in the case of swizzled smem-backed thread tiles. This ensures
 // that paged and non-paged copies result in equivalently shaped, if not necessarily strided, tensors.
 template <typename Layout>
@@ -354,6 +354,17 @@ auto reshape_thread_tile(Layout l) {
                        append(get<0>(l.stride()), get<2>(l.stride())));
 }
 
+// Reshapes and flattens the thread tile layout. A separate function is needed for the case where
+// one of the modes of l is a layout itself and must be flattened, as opposed to keeping it intact
+// for the case of swizzled layouts.
+template <typename Layout>
+__forceinline__ __device__
+auto reshape_flatten_thread_tile(Layout l) {
+    auto mode_0 = filter(flatten(get<0>(l)));
+    return make_layout(append(mode_0.shape(), get<2>(l.shape())),
+                       append(mode_0.stride(), get<2>(l.stride())));
+}
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template
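The flattening variant is again easiest to see on plain tuples. A minimal Python sketch, shapes only and with illustrative values; the real function also runs cute::filter over the flattened mode, a step omitted here:

    def flatten_tuple(t):
        # Recursively flatten a nested tuple, loosely mirroring cute::flatten.
        if not isinstance(t, tuple):
            return (t,)
        return tuple(x for item in t for x in flatten_tuple(item))

    def reshape_flatten_thread_tile(shape):
        # ((v1, (v2a, v2b)), m, k) -> (v1, v2a, v2b, k): flatten mode 0, append k.
        mode_0, _m, k = shape
        return flatten_tuple(mode_0) + (k,)

    assert reshape_flatten_thread_tile(((8, (2, 4)), 4, 8)) == (8, 2, 4, 8)

Two entry points are needed because swizzled smem-backed tiles must keep their nested mode intact (reshape_thread_tile), while the contiguous rotary tiles want the value mode fully flattened.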
From: skrider
Date: Mon, 26 Feb 2024 07:31:16 +0000
Subject: [PATCH 16/19] revert hardcoded rotcossin thread layout

---
 csrc/flash_attn/src/kernel_traits.h | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/csrc/flash_attn/src/kernel_traits.h b/csrc/flash_attn/src/kernel_traits.h
index 8556bdd46..04a9b3b29 100644
--- a/csrc/flash_attn/src/kernel_traits.h
+++ b/csrc/flash_attn/src/kernel_traits.h
@@ -158,9 +158,7 @@ struct Flash_fwd_kernel_traits : public Base {
         make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                         GmemLayoutAtomOaccum{},
                         Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
-    // using GmemLayoutAtomRotcossin = GmemLayoutAtom;
-    using GmemLayoutAtomRotcossin = Layout<Shape<Int<...>, Int<...>>,
-                                           Stride<Int<...>, _1>>;
+    using GmemLayoutAtomRotcossin = GmemLayoutAtom;
     using GmemTiledCopyRotcossin = decltype(
         make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
                         GmemLayoutAtomRotcossin{},
From 135a1da6138a4bee61f91d7101693f699d63f780 Mon Sep 17 00:00:00 2001
From: skrider
Date: Tue, 26 Mar 2024 01:33:51 +0000
Subject: [PATCH 17/19] resolve page offsets absolutely, not relatively

---
 csrc/flash_attn/src/flash_fwd_kernel.h | 16 ++++++-------
 csrc/flash_attn/src/utils.h            | 32 ++------------------------
 2 files changed, 10 insertions(+), 38 deletions(-)

diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h
index 9ae7dd27e..34922d519 100644
--- a/csrc/flash_attn/src/flash_fwd_kernel.h
+++ b/csrc/flash_attn/src/flash_fwd_kernel.h
@@ -609,9 +609,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor tVsV = make_tensor(tVsV_.data(), reshape_thread_tile(tVsV_.layout()));
 
     if (block_table != nullptr) {
-        tKgK.data() = gK.data() + flash::init_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
+        tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
                                                                                          block_table, params.k_batch_stride, params.k_row_stride);
-        tVgV.data() = gV.data() + flash::init_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
+        tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
                                                                                          block_table, params.v_batch_stride, params.v_row_stride);
     }
 
@@ -769,9 +769,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
                 tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
             } else {
                 if (n_block > n_block_copy_min) {
-                    tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
+                    tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
                                                                                                          block_table, params.v_batch_stride, params.v_row_stride);
-                    tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
+                    tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
                                                                                                          block_table, params.k_batch_stride, params.k_row_stride);
                 }
             }
@@ -865,7 +865,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         if (block_table == nullptr) {
             tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
         } else {
-            tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size,
+            tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size,
                                                                                                  block_table, params.v_batch_stride, params.v_row_stride);
         }
         flash::copy(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
@@ -897,7 +897,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         if (block_table == nullptr) {
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
         } else {
-            tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
+            tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
                                                                                                  block_table, params.k_batch_stride, params.k_row_stride);
         }
         flash::copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
@@ -937,7 +937,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         if (block_table == nullptr) {
             tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
         } else {
-            tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size,
+            tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size,
                                                                                                  block_table, params.v_batch_stride, params.v_row_stride);
         }
         flash::copy(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
@@ -955,7 +955,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         if (block_table == nullptr) {
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
         } else {
-            tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
+            tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
                                                                                                  block_table, params.k_batch_stride, params.k_row_stride);
         }
         flash::copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);

diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h
index 46b2ea039..4f999a6b7 100644
--- a/csrc/flash_attn/src/utils.h
+++ b/csrc/flash_attn/src/utils.h
@@ -292,11 +292,11 @@ void cp_async_wait() {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-// resolves initial base offset of a slice of a paged kv copy from gmem.
+// resolves offset of a slice of a paged kv copy from gmem.
 // assumes that the tensor has already been positioned at the correct head.
 template <typename Kernel_traits>
 __forceinline__ __device__
-int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size,
+int resolve_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size,
                                      const int* block_table, const int page_stride, const int row_stride) {
     constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow;
     constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread;
@@ -313,34 +313,6 @@ int advance_thread_kv_page_slice_offset(const int tidx, const int n_block_max, cons
         + page_offset * row_stride + col_offset;
 }
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-// advances base address of a slice of a paged copy from gmem
-template <typename Kernel_traits>
-__forceinline__ __device__
-int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const int page_block_size,
-                                        const int* block_table, const int page_stride, const int row_stride) {
-    constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow;
-    constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread;
-    constexpr int kBlockN = Kernel_traits::kBlockN;
-
-    const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
-
-    const int global_row_offset_cur = block_row_offset + n_block * kBlockN;
-    const int global_row_offset_next = block_row_offset + (n_block - 1) * kBlockN;
-
-    const int page_offset_cur = global_row_offset_cur % page_block_size;
-    const int page_offset_next = global_row_offset_next % page_block_size;
-
-    const int virtual_page_idx_cur = global_row_offset_cur / page_block_size;
-    const int virtual_page_idx_next = global_row_offset_next / page_block_size;
-
-    const int table_diff = block_table[virtual_page_idx_next] - block_table[virtual_page_idx_cur];
-    const int offset_diff = page_offset_next - page_offset_cur;
-
-    return table_diff * page_stride + offset_diff * row_stride;
-}
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
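The surviving resolve helper computes every thread's position afresh from the page table, so its core arithmetic can be checked by hand. A small Python model (the toy page table and strides are made up, and the kernel additionally folds in a per-thread row and column offset):

    def resolve_offset(global_row, block_table, page_block_size, page_stride, row_stride):
        virtual_page_idx = global_row // page_block_size   # logical page holding the row
        page_offset = global_row % page_block_size         # row within that page
        return block_table[virtual_page_idx] * page_stride + page_offset * row_stride

    block_table = [7, 3, 9]        # physical page per logical page
    page_stride = 16 * 128         # elements per 16-row page with head dim 128
    print(resolve_offset(40, block_table, 16, page_stride, 128))  # page 2 -> block 9, row 8

The removed advance helper instead added a delta between the current and next position to the previous pointer, which required reading the block table at both positions, including positions one block past the slice being copied; resolving absolutely avoids accumulating those reads, which is what the page-table overflow test added next exercises.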
From a63157ea8c872b568fde133f58dea6e061b3e7a2 Mon Sep 17 00:00:00 2001
From: skrider
Date: Tue, 26 Mar 2024 02:10:02 +0000
Subject: [PATCH 18/19] add test for page table overflow

---
 tests/test_flash_attn.py | 44 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index 17859cae8..7354ab312 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -2461,3 +2461,47 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus
     assert torch.equal(dv, dv)
     assert torch.equal(dk, dk)
     assert torch.equal(dq, dq)
+
+
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("causal", [False, True])
+# @pytest.mark.parametrize("causal", [False])
+@pytest.mark.parametrize("paged_kv_block_size", [16])
+# @pytest.mark.parametrize("has_batch_idx", [False])
+@pytest.mark.parametrize("d", [128])
+@pytest.mark.parametrize("nheads", [32])
+@pytest.mark.parametrize("b", [4])
+@pytest.mark.parametrize("n", [10])
+@pytest.mark.parametrize("seqlen_q,seqlen_k", [(170, 170)])
+def test_flash_attn_paged_kvcache_overflow(
+    seqlen_q,
+    seqlen_k,
+    d,
+    nheads,
+    b,
+    n,
+    paged_kv_block_size,
+    causal,
+    dtype,
+):
+    device = "cuda"
+    num_blocks = 1000 * 16 // paged_kv_block_size
+    key_cache = torch.rand([num_blocks, paged_kv_block_size, nheads, d], dtype=dtype, device=device)
+    value_cache = torch.rand([num_blocks, paged_kv_block_size, nheads, d], dtype=dtype, device=device)
+    cache_seqlens = torch.zeros(b, dtype=torch.int32, device=device)
+
+    for _ in range(n):
+        query = torch.rand([b, seqlen_q, nheads, d], dtype=dtype, device=device)
+        key = torch.rand([b, seqlen_k, nheads, d], dtype=dtype, device=device)
+        value = torch.rand([b, seqlen_k, nheads, d], dtype=dtype, device=device)
+        block_tables = torch.randint(0, num_blocks, size=(b, (seqlen_k + paged_kv_block_size - 1) // paged_kv_block_size), dtype=torch.int32, device=device)
+        output = flash_attn_with_kvcache(
+            query,
+            key_cache,
+            value_cache,
+            k=key,
+            v=value,
+            cache_seqlens=cache_seqlens,
+            block_table=block_tables,
+            causal=causal,
+        )

From 7968148214850ede9fa4bd515316c643ebd8ae83 Mon Sep 17 00:00:00 2001
From: skrider
Date: Tue, 26 Mar 2024 06:33:40 +0000
Subject: [PATCH 19/19] allow smaller page sizes in varlen api

---
 csrc/flash_attn/flash_api.cpp | 2 +-
 tests/test_flash_attn.py     | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index a22752551..75ba3ed22 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -561,7 +561,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
     const int num_blocks = !paged_KV ? 0 : k.size(0);
     const int page_block_size = !paged_KV ? 1 : k.size(1);
-    TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
+    TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
     if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
     if (is_causal) { window_size_right = 0; }

diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index 7354ab312..65af6af10 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -1543,7 +1543,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
     ],
 )
 # TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged
-@pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
+@pytest.mark.parametrize("paged_kv_block_size", [None, 16, 256, 512])
 # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
 def test_flash_attn_varlen_causal(
     seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype
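For reference, this is how the paged-KV decode path exercised by the new tests looks from Python; a minimal usage sketch, assuming a build with this series applied and with illustrative sizes:

    import torch
    from flash_attn import flash_attn_with_kvcache

    b, nheads, d, page = 2, 8, 128, 16
    num_blocks = 64
    key_cache = torch.rand(num_blocks, page, nheads, d, dtype=torch.float16, device="cuda")
    value_cache = torch.rand_like(key_cache)
    q = torch.rand(b, 1, nheads, d, dtype=torch.float16, device="cuda")

    # One row of page indices per sequence; entries point into the shared block pool.
    block_table = torch.randint(0, num_blocks, (b, 4), dtype=torch.int32, device="cuda")
    cache_seqlens = torch.full((b,), 3 * page, dtype=torch.int32, device="cuda")

    out = flash_attn_with_kvcache(
        q, key_cache, value_cache,
        cache_seqlens=cache_seqlens,
        block_table=block_table,
        causal=True,
    )
    print(out.shape)  # torch.Size([2, 1, 8, 128])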