diff --git a/src/kernels/attention/attention_bench_sm80.cu b/src/kernels/attention/attention_bench_sm80.cu index 6cd7be6d..53c2c250 100644 --- a/src/kernels/attention/attention_bench_sm80.cu +++ b/src/kernels/attention/attention_bench_sm80.cu @@ -74,8 +74,8 @@ void attention_bench_sm80(nvbench::state& state) { state.exec([&](nvbench::launch& launch) { DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] { - run_attention_kernel_sm80(params, - launch.get_stream()); + run_attention_kernel_sm80( + params, launch.get_stream()); }); }); } diff --git a/src/kernels/attention/attention_kernel_sm80.cuh b/src/kernels/attention/attention_kernel_sm80.cuh index ef6f959b..7db50ba7 100644 --- a/src/kernels/attention/attention_kernel_sm80.cuh +++ b/src/kernels/attention/attention_kernel_sm80.cuh @@ -10,6 +10,7 @@ #include "cute/config.hpp" #include "cute_extensions.cuh" #include "fast_cast.cuh" +#include "layout_conformance.cuh" #include "mask.h" #include "online_softmax.cuh" #include "ptx.cuh" @@ -38,6 +39,7 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) { // type alias using DType = typename Traits::DType; + using KV_DType = typename Traits::KV_DType; using TiledMma = typename Traits::TiledMma; using Layout = typename Traits::LayoutConvertor; @@ -69,10 +71,10 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) { auto [Q, O] = tile.template get_qo_tile(batch_idx, head_idx); // (kv_len, HEAD_DIM) auto [K, V] = - tile.template get_kv_tile(batch_idx, head_idx / group_size); + tile.template get_kv_tile(batch_idx, head_idx / group_size); - const int q_len = size<0>(Q.shape()); - const int kv_len = size<0>(K.shape()); + const int q_len = size<0>(Q); + const int kv_len = size<0>(K); if (m_block * kBlockM >= q_len) { // out of bound, return @@ -110,8 +112,8 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) { // Smem extern __shared__ char smem[]; DType* q_smem = (DType*)smem; - DType* k_smem = q_smem + cosize(SmemLayoutQ{}); - DType* v_smem = k_smem + cosize(SmemLayoutK{}); + KV_DType* k_smem = q_smem + cosize(SmemLayoutQ{}); + KV_DType* v_smem = k_smem + cosize(SmemLayoutK{}); // (BLK_M, BLK_K), k-major Tensor sQ = make_tensor(make_smem_ptr(q_smem), SmemLayoutQ{}); @@ -141,10 +143,10 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) { auto produce_q = [&]() { auto tQgQ = gmem_thr_copy_Q.partition_S(gQ); auto tQsQ = gmem_thr_copy_Q.partition_D(sQ); - safe_copy( + safe_copy( gmem_tiled_copy_Q, tQgQ, tQsQ, @@ -157,37 +159,36 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) { auto produce_k = [&](int ni) { auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni)); // skip zero fill oob for k since mask will mask out oob with -inf - safe_copy( - gmem_tiled_copy_Q, - tKgK, - tKsK, - tKcKV, - make_coord(kv_len - ni * kBlockN, head_dim)); + safe_copy(gmem_tiled_copy_Q, + tKgK, + tKsK, + tKcKV, + make_coord(kv_len - ni * kBlockN, head_dim)); }; Tensor tVsV = gmem_thr_copy_KV.partition_D(sV); auto produce_v = [&](int ni) { auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni)); // TODO: skip zero fill oob for v, may have nan issue - safe_copy( - gmem_tiled_copy_Q, - tVgV, - tVsV, - tKcKV, - make_coord(kv_len - ni * kBlockN, head_dim)); + safe_copy(gmem_tiled_copy_Q, + tVgV, + tVsV, + tKcKV, + make_coord(kv_len - ni * kBlockN, head_dim)); }; TiledMma tiled_mma; auto thr_mma = tiled_mma.get_slice(tidx); // GEMM-I: S = Q@K.T auto tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) - auto tSrK = 
thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + // allocate register with type KV_DType for K + auto tSrK = make_fragment_B(thr_mma, sK); // (MMA,MMA_N,MMA_K) // s2r tiled copy for qkv SmemTiledCopyQ smem_tiled_copy_Q; @@ -219,12 +220,18 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) { tSsK(_, _, next_ki), tSrK_copy_view(_, _, next_ki)); } - cute::gemm(tiled_mma, tSrQ(_, _, ki), tSrK(_, _, ki), tSrAccS); + + if constexpr (!is_same_v<DType, KV_DType>) { + // fragment layout swizzle between threads + frag_B_layout_swizzle(tSrK_copy_view, tidx); + } + mixed_gemm(tiled_mma, tSrQ(_, _, ki), tSrK(_, _, ki), tSrAccS); } }; // GEMM-II: O = softmax(S)@V - auto tOrVt = thr_mma.partition_fragment_B(sVt); // (MMA,MMA_K,MMA_N) + // allocate register with type KV_DType for V^t + auto tOrVt = make_fragment_B(thr_mma, sVt); // (MMA,MMA_K,MMA_N) SmemTiledCopyVt smem_tiled_copy_Vt; auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_thread_slice(tidx); @@ -254,7 +261,12 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) { tOsVt(_, _, next_ki), tOrVt_copy_view(_, _, next_ki)); } - cute::gemm(tiled_mma, tOrS(_, _, ki), tOrVt(_, _, ki), tOrAccO); + + if constexpr (!is_same_v<DType, KV_DType>) { + // fragment layout swizzle between threads + frag_B_trans_layout_swizzle(tOrVt_copy_view, tidx); + } + mixed_gemm(tiled_mma, tOrS(_, _, ki), tOrVt(_, _, ki), tOrAccO); } }; @@ -288,10 +300,10 @@ __global__ void mha_kernel_sm80(__grid_constant__ const Params params) { // wait for smem copy done before gmem copy __syncthreads(); - safe_copy( + safe_copy( gmem_tiled_copy_O, tOsO, tOgO, diff --git a/src/kernels/attention/attention_kernel_sm80_pagedkv_test.cu b/src/kernels/attention/attention_kernel_sm80_pagedkv_test.cu index e3574b04..33cd34db 100644 --- a/src/kernels/attention/attention_kernel_sm80_pagedkv_test.cu +++ b/src/kernels/attention/attention_kernel_sm80_pagedkv_test.cu @@ -63,7 +63,7 @@ torch::Tensor attention_pagedkv_sm80( DISPATCH_TORCH_DTYPE(query.dtype(), DTYPE, [&] { DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] { - run_attention_kernel_sm80(params); + run_attention_kernel_sm80(params); }); }); return out; diff --git a/src/kernels/attention/attention_kernel_sm80_test.cu b/src/kernels/attention/attention_kernel_sm80_test.cu index f78c9e72..636aa495 100644 --- a/src/kernels/attention/attention_kernel_sm80_test.cu +++ b/src/kernels/attention/attention_kernel_sm80_test.cu @@ -59,7 +59,7 @@ torch::Tensor attention_sm80( DISPATCH_TORCH_DTYPE(query.dtype(), DTYPE, [&] { DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] { - run_attention_kernel_sm80(params); + run_attention_kernel_sm80(params); }); }); return out; diff --git a/src/kernels/attention/attention_kernel_sm80_varlen_test.cu b/src/kernels/attention/attention_kernel_sm80_varlen_test.cu index e1eee60a..d091f7fc 100644 --- a/src/kernels/attention/attention_kernel_sm80_varlen_test.cu +++ b/src/kernels/attention/attention_kernel_sm80_varlen_test.cu @@ -156,7 +156,7 @@ torch::Tensor attention_varlen_sm80( DISPATCH_TORCH_DTYPE(query.dtype(), DTYPE, [&] { DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] { - run_attention_kernel_sm80(params); + run_attention_kernel_sm80(params); }); }); return out; diff --git a/src/kernels/attention/attention_launch_sm80.cuh b/src/kernels/attention/attention_launch_sm80.cuh index fecb39b5..0f08ef8a 100644 --- a/src/kernels/attention/attention_launch_sm80.cuh +++ b/src/kernels/attention/attention_launch_sm80.cuh @@ -51,35 +51,39 @@ void run_attention_kernel(const Params& params, cudaStream_t stream) { } // namespace detail //
user-facing function to run the attention kernel -template +template void run_attention_kernel_sm80(Params& params, cudaStream_t stream = nullptr) { // normalize params that for performance optimization params.normalize(); // TODO: tune block shape MNK based on the head dim and smem size if constexpr (HEAD_DIM == 64) { - using Traits = AttentionTraitsSM80; detail::run_attention_kernel(params, stream); } else if constexpr (HEAD_DIM == 96) { - using Traits = AttentionTraitsSM80; detail::run_attention_kernel(params, stream); } else if constexpr (HEAD_DIM == 128) { - using Traits = AttentionTraitsSM80; detail::run_attention_kernel(params, stream); } else if constexpr (HEAD_DIM == 256) { - using Traits = AttentionTraitsSM80(params, stream); } else { // use the default block size - using Traits = AttentionTraitsSM80 +CUTE_HOST_DEVICE constexpr auto tiled_mma_selector() { + CUTE_STATIC_ASSERT( + sizeof_bits_v == 16 || sizeof_bits_v == 8, + "KV_DType must be 8 or 16 bits"); + + using MMA_Atom = std::conditional_t, + MMA_Atom, + MMA_Atom>; + using TiledMma = TiledMMA>, // warp layout 4x1x1 + Tile<_64, _16, _16>>; // Prom Shape 64x16x16 + return TiledMma{}; +} + +template +CUTE_HOST_DEVICE constexpr auto tiled_copy_selector() { + using DType = typename COPY_Atom::ValType; + // use 128 bits vectorized copy + constexpr int kValPerThr = 128 / sizeof_bits_v; + constexpr int kThrsPerRow = BLK_K / kValPerThr; + using ThrLayout = Layout, Int>, + Stride, _1>>; + using ValLayout = Layout>>; + return make_tiled_copy(COPY_Atom{}, ThrLayout{}, ValLayout{}); +} + +template +CUTE_HOST_DEVICE constexpr auto tiled_copy_B_selector() { + if constexpr (sizeof_bits_v == 16) { + // ((_4,_8,_4),((_2,_2),(_2,_1))):((_32,_1,_0),((_16,_128),(_8,_0))) + return make_tiled_copy_B(Copy_Atom{}, + TiledMma{}); + } else if constexpr (sizeof_bits_v == 8) { + // ((_4, _8), (_4, _2)):((_64, _1), (_16, _8)) + using Layout_TV_K = Layout, Shape<_4, _2>>, + Stride, Stride<_16, _8>>>; + // use cute::uint8_t as InternalType + using SmemTiledCopyK = + TiledCopy, + Layout_TV_K, + Shape<_16, _16>>; // N x K + return SmemTiledCopyK{}; + } else { + CUTE_STATIC_ASSERT( + sizeof_bits_v == 8 || sizeof_bits_v == 16, + "KV_DType must be 8 or 16 bits"); + } +} + +template +CUTE_HOST_DEVICE constexpr auto tiled_copy_B_T_selector() { + if constexpr (sizeof_bits_v == 16) { + // ((_4,_8,_4),((_2,_2),(_2,_1))):((_32,_1,_0),((_16,_128),(_8,_0))) + return make_tiled_copy_B(Copy_Atom{}, TiledMma{}); + } else if constexpr (sizeof_bits_v == 8) { + // ((_4, _8), (_2, _2, _2)):((_32, _2), (_1, _16, _128)) + using Layout_TV_Vt = Layout, Shape<_2, _2, _2>>, + Stride, Stride<_1, _16, _128>>>; + // use cute::uint8_t as InternalType + using SmemTiledCopyVt = + TiledCopy, + Layout_TV_Vt, + Shape<_16, _16>>; // K x N + return SmemTiledCopyVt{}; + } else { + CUTE_STATIC_ASSERT_V( + sizeof_bits_v == 8 || sizeof_bits_v == 16, + "DType must be 8 or 16 bits"); + } +} + } // namespace detail -template +template struct AttentionTraitsSM80 { // helpful aliases static constexpr int kHeadDim = HEAD_DIM; @@ -41,6 +118,7 @@ struct AttentionTraitsSM80 { static constexpr int kRowsPerMMA = 2; using DType = DTYPE; + using KV_DType = KV_DTYPE; using _BLK_M = Int; using _BLK_N = Int; using _BLK_K = Int; @@ -49,13 +127,8 @@ struct AttentionTraitsSM80 { // ******* Mainloop ******* // TiledMMA (64x16x16) for gemm-I and gemm-II // choose MMA_Atom based on Element type - using MMA_Atom_ = - std::conditional_t, - MMA_Atom, - MMA_Atom>; - using TiledMma = TiledMMA>, // warp layout 4x1x1 - 
Tile<_64, _16, _16>>; // Prom Shape 64x16x16 + using TiledMma = decltype(detail::tiled_mma_selector()); + static constexpr size_t kThreadNum = size(TiledMma{}); // Layout convertor for TiledMMA (64x16x16) using LayoutConvertor = detail::LayoutConvertor; @@ -77,40 +150,39 @@ struct AttentionTraitsSM80 { using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_N, _HEAD_DIM>{})); - // V^T smem: (HEAD_DIM, BLK_N) row-major + // V^T smem (transpose view of V): (HEAD_DIM, BLK_N) using SmemLayoutVt = decltype(composition( SmemLayoutV{}, make_layout(Shape<_HEAD_DIM, _BLK_N>{}, GenRowMajor{}))); - // Thr layout for gmem copy - using GmemCopyThrLayout = - std::conditional_t, Stride<_4, _1>>, - Layout, Stride<_8, _1>>>; - // Tiled copy for QKV // g2s tiled copy for q - using GmemTiledCopyQ = decltype(make_tiled_copy( - Copy_Atom, DType>{}, - GmemCopyThrLayout{}, // Thr layout: (_16,_8)/(_32, _4) - Layout>{} // Val layout: 8 vals per read - )); + using GmemTiledCopyQ = + decltype(detail::tiled_copy_selector< + Copy_Atom, + DType>, + BLK_K, + kThreadNum>()); // g2s tiled copy for kv - using GmemTiledCopyKV = GmemTiledCopyQ; + // TODO: choose based on BLK_K and kv cache type + using GmemTiledCopyKV = + decltype(detail::tiled_copy_selector< + Copy_Atom, + KV_DType>, + BLK_K, + kThreadNum>()); // s2r tiled copy for gemm-I using SmemTiledCopyQ = decltype(make_tiled_copy_A(Copy_Atom{}, TiledMma{})); using SmemTiledCopyK = - decltype(make_tiled_copy_B(Copy_Atom{}, - TiledMma{})); + decltype(detail::tiled_copy_B_selector()); // s2r tiled copy for gemm-II using SmemTiledCopyVt = - decltype(make_tiled_copy_B(Copy_Atom{}, - TiledMma{})); + decltype(detail::tiled_copy_B_T_selector()); // ******* Epilogue ******* @@ -118,11 +190,10 @@ struct AttentionTraitsSM80 { using SmemLayoutO = SmemLayoutQ; // s2g tiled copy for O - using GmemTiledCopyO = decltype(make_tiled_copy( - Copy_Atom{}, - GmemCopyThrLayout{}, // Thr layout: (_16,_8)/(_32, _4) - Layout>{} // Val layout: 8 vals per read - )); + using GmemTiledCopyO = + decltype(detail::tiled_copy_selector, + BLK_K, + kThreadNum>()); // r2s tiled copy for O using SmemTiledCopyO = @@ -130,10 +201,8 @@ struct AttentionTraitsSM80 { // constexpr values for kernel launch static constexpr size_t kSmemSize = - (cosize(SmemLayoutQ{}) + cosize(SmemLayoutK{}) + cosize(SmemLayoutV{})) * - sizeof(DType); - - static constexpr size_t kThreadNum = size(TiledMma{}); + cosize(SmemLayoutQ{}) * sizeof(DType) + + (cosize(SmemLayoutK{}) + cosize(SmemLayoutV{})) * sizeof(KV_DType); }; } // namespace llm \ No newline at end of file diff --git a/src/kernels/attention/attention_traits_test.cpp b/src/kernels/attention/attention_traits_test.cpp index 14d39975..75c42398 100644 --- a/src/kernels/attention/attention_traits_test.cpp +++ b/src/kernels/attention/attention_traits_test.cpp @@ -47,6 +47,7 @@ void test_attention_traits() { TEST(AttentionTraitsTest, TraitsSM80) { test_attention_traits +CUTE_HOST_DEVICE constexpr auto make_fragment_B(const ThrMMA& thr_mma, + BTensor const& btensor) { + return make_fragment_like(thr_mma.partition_B(btensor)); +} + template CUTE_HOST_DEVICE constexpr auto elem_less(IntTupleA const& a, IntTupleB const& b) { return elem_less(get(a), get(b)); } -template +CUTE_HOST_DEVICE void zfill(const Copy_Atom& copy_atom, + const TensorS& src, + TensorD&& dst) { + CUTE_STATIC_ASSERT(TensorS::rank == TensorD::rank, "rank-mismatch."); + + auto has_with_bool = cute::is_valid( + [](auto t) -> void_t() + .with(true))> {}, + copy_atom); + if constexpr 
(has_with_bool) { + constexpr int R = TensorD::rank; + if constexpr (R == 1) { // Dispatch the copy + copy_atom.with(false).call(src, dst); + } else { // Loop over all but the first mode + Tensor src_v = group_modes<1, R>(src); + Tensor dst_v = group_modes<1, R>(dst); + CUTE_UNROLL + for (int i = 0; i < size<1>(dst_v); ++i) { + copy_atom.with(false).call(src_v(_, i), dst_v(_, i)); + } + } + } else { + // just call clear() if the copy atom has no with() method + clear(dst); + } +} + +template +CUTE_HOST_DEVICE void zfill(const Copy_Atom& copy_atom, + const TensorS& src, + TensorD& dst) { + zfill(copy_atom, src, dst); +} + +template CUTE_HOST_DEVICE void safe_copy( - const TiledCopy& tiled_copy, + const TiledCopy& tiled_copy, const TensorS& src, // (CPY, CPY_M/N, CPY_K) TensorD& dst, // (CPY, CPY_M/N, CPY_K) const TensorC& identity, // (CPY, CPY_M/N, CPY_K) -> (blk_m/n, blk_k) const Coord& max_coord // max_coord(blk_m/n, blk_k) ) { + CUTE_STATIC_ASSERT(TensorS::rank == TensorD::rank, "rank-mismatch."); + auto copy_atom = static_cast(tiled_copy); + if constexpr (!EVEN_MN && !EVEN_K) { // handle both m/n and k oob CUTE_UNROLL @@ -39,16 +86,16 @@ CUTE_HOST_DEVICE void safe_copy( CUTE_UNROLL for (int ki = 0; ki < size<2>(src); ++ki) { if (elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) { - copy(tiled_copy, src(_, mi, ki), dst(_, mi, ki)); + copy(copy_atom, src(_, mi, ki), dst(_, mi, ki)); } else { - if constexpr (ZERO_FILL_K) { - clear(dst(_, mi, ki)); + if constexpr (ZFILL_K) { + zfill(copy_atom, src(_, mi, ki), dst(_, mi, ki)); } } } } else { - if constexpr (ZERO_FILL_MN) { - clear(dst(_, mi, _)); + if constexpr (ZFILL_MN) { + zfill(copy_atom, src(_, mi, _), dst(_, mi, _)); } } } @@ -57,10 +104,10 @@ CUTE_HOST_DEVICE void safe_copy( CUTE_UNROLL for (int mi = 0; mi < size<1>(src); ++mi) { if (elem_less<0>(identity(_0{}, mi, _0{}), max_coord)) { - copy(tiled_copy, src(_, mi, _), dst(_, mi, _)); + copy(copy_atom, src(_, mi, _), dst(_, mi, _)); } else { - if constexpr (ZERO_FILL_MN) { - clear(dst(_, mi, _)); + if constexpr (ZFILL_MN) { + zfill(copy_atom, src(_, mi, _), dst(_, mi, _)); } } } @@ -69,16 +116,51 @@ CUTE_HOST_DEVICE void safe_copy( CUTE_UNROLL for (int ki = 0; ki < size<2>(src); ++ki) { if (elem_less<1>(identity(_0{}, _0{}, ki), max_coord)) { - copy(tiled_copy, src(_, _, ki), dst(_, _, ki)); + copy(copy_atom, src(_, _, ki), dst(_, _, ki)); } else { - if constexpr (ZERO_FILL_K) { - clear(dst(_, _, ki)); + if constexpr (ZFILL_K) { + zfill(copy_atom, src(_, _, ki), dst(_, _, ki)); } } } } else { // no oob, just copy - copy(tiled_copy, src, dst); + copy(copy_atom, src, dst); + } +} + +// support mixed precision mma +// Dispatch [4]: (V,M) x (V,N) => (V,M,N) +template +CUTE_HOST_DEVICE void mixed_gemm(MMA_Atom const& mma, + const FragmentA& A, // (V,M) Logical data + const FragmentB& B, // (V,N) Logical data + FragmentC& C) // (V,M,N) Logical data +{ + using AType = typename FragmentA::value_type; + using BType = typename FragmentB::value_type; + + if constexpr (std::is_same_v<AType, BType>) { + // same type, call gemm + gemm(mma, A, B, C); + } else { + // handle mixed precision + auto M = size<1>(A); + auto N = size<1>(B); + + // Col-major serpentine iteration + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + // Convert B to the same type as A before gemm + auto B_ = make_fragment_like<AType>(B(_, n)); + fast_cast(B(_, n), B_); + + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + int ms = (n & 1) ?
M - 1 - m : m; // Serpentine coordinate + gemm(mma, A(_, ms), B_, C(_, ms, n)); + } + } } } diff --git a/src/kernels/attention/fast_cast.cuh b/src/kernels/attention/fast_cast.cuh index 76642dfa..822888ab 100644 --- a/src/kernels/attention/fast_cast.cuh +++ b/src/kernels/attention/fast_cast.cuh @@ -46,7 +46,34 @@ struct type_cast { } } }; -// TODO: add other specializations + +template <> +struct type_cast { + template + CUTE_DEVICE static void cast(const FragmentS& src, FragmentD& dst) { + // TODO: implement fast float_e4m3_t -> half_t + CUTE_UNROLL + for (int i = 0; i < size(src); ++i) { + dst(i) = cute::half_t(src(i)); + } + } +}; + +template <> +struct type_cast { + template + CUTE_DEVICE static void cast(const FragmentS& src, FragmentD& dst) { + // TODO: implement fast float_e5m2_t -> half_t + CUTE_UNROLL + for (int i = 0; i < size(src); ++i) { + dst(i) = cute::half_t(src(i)); + } + } +}; + +// TODO: implement the following specializations +// specialization for float_e4m3_t -> bfloat16 +// specialization for float_e5m2_t -> bfloat16 } // namespace detail diff --git a/src/kernels/attention/generate_instantiation_cu.py b/src/kernels/attention/generate_instantiation_cu.py index 34604151..fedc11de 100755 --- a/src/kernels/attention/generate_instantiation_cu.py +++ b/src/kernels/attention/generate_instantiation_cu.py @@ -12,6 +12,10 @@ "bf16": "cute::bfloat16_t", } +# TODO: add support for mixed precision kernels +KV_DTYPE_MAP = { +} + HEAD_DIMENSIONS = [64, 96, 128, 256] PAGEDKV_KERNEL_IMPL_TEMPLATE = """ @@ -20,7 +24,7 @@ namespace llm {{ using Params = PagedKVAttentionParams; -template void run_attention_kernel_sm80<{DTYPE}, {HEAD_DIM}, Params>( +template void run_attention_kernel_sm80<{DTYPE}, {KV_DTYPE}, {HEAD_DIM}, Params>( Params& params, cudaStream_t stream); }} // namespace llm @@ -29,22 +33,31 @@ @dataclass class Kernel: dtype: str + kv_dtype: str head_dim: int @property def template(self) -> str: return PAGEDKV_KERNEL_IMPL_TEMPLATE.format( - DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim + DTYPE=DTYPE_MAP[self.dtype], KV_DTYPE=DTYPE_MAP[self.kv_dtype], HEAD_DIM=self.head_dim ) @property def filename(self) -> str: - return f"attention_{self.dtype}_hd{self.head_dim}_sm80.cu" + if self.dtype == self.kv_dtype: + return f"attention_{self.dtype}_hd{self.head_dim}_sm80.cu" + # include the kv dtype in the filename + return f"attention_{self.dtype}_{self.kv_dtype}_hd{self.head_dim}_sm80.cu" def get_all_kernels() -> Iterator[Kernel]: + # fp16 and bf16 kernels for dtype, head_dim in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS): - yield Kernel(dtype=dtype, head_dim=head_dim) + yield Kernel(dtype=dtype, kv_dtype=dtype, head_dim=head_dim) + + # mixed precision kernels + for dtype, kv_dtype, head_dim in itertools.product(DTYPE_MAP.keys(), KV_DTYPE_MAP.keys(), HEAD_DIMENSIONS): + yield Kernel(dtype=dtype, kv_dtype=kv_dtype, head_dim=head_dim) if __name__ == "__main__": diff --git a/src/kernels/attention/layout_conformance.cuh b/src/kernels/attention/layout_conformance.cuh new file mode 100644 index 00000000..faaeb825 --- /dev/null +++ b/src/kernels/attention/layout_conformance.cuh @@ -0,0 +1,83 @@ +#pragma once + +#include + +#include +#include + +namespace llm { + +namespace detail { +using namespace cute; + +CUTE_DEVICE void swap(uint16_t& a, uint16_t& b) { + auto tmp = a; + a = b; + b = tmp; +} + +// adapted from https://github.com/flashinfer-ai/flashinfer + +// T0( 0, 1, 2, 3) => T0( 0, 1, 8, 9) +// T1( 4, 5, 6, 7) => T1( 2, 3, 10, 11) +// T2( 8, 9, 10, 11) => T2( 4, 5, 12, 13) 
+// T3( 12, 13, 14, 15) => T3( 6, 7, 14, 15) +CUTE_DEVICE uint32_t frag_B_layout_swizzle_8b(uint32_t x, int tidx) { + uint32_t tmp = __shfl_xor_sync(0xffffffff, x, 0x1); + x = __byte_perm(x, tmp, ((tidx & 0x1) == 0) ? 0x5410 : 0x3276); + tmp = __shfl_xor_sync(0xffffffff, x, 0x2); + x = __byte_perm(x, tmp, ((tidx & 0x2) == 0) ? 0x5410 : 0x3276); + return x; +} + +// T0: ( 0, 16, 1, 17) => T0( 0, 1, 128, 129) +// T4: ( 32, 48, 33, 49) => T4( 16, 17, 144, 145) +// T8: ( 64, 80, 65, 81) => T8( 32, 33, 160, 161) +// T12:( 96, 112, 97, 113) => T12( 48, 49, 176, 177) +// T16:(128, 144, 129, 145) => T16( 64, 65, 192, 193) +// T20:(160, 176, 161, 177) => T20( 80, 81, 208, 209) +// T24:(192, 208, 193, 209) => T24( 96, 97, 224, 225) +// T28:(224, 240, 225, 241) => T28(112, 113, 240, 241) +CUTE_DEVICE uint32_t frag_B_trans_layout_swizzle_8b(uint32_t x, int tidx) { + uint32_t tmp = __shfl_xor_sync(0xffffffff, x, 0x4); + x = __byte_perm(x, tmp, ((tidx & 0x4) == 0) ? 0x6420 : 0x3175); + tmp = __shfl_xor_sync(0xffffffff, x, 0x8); + x = __byte_perm(x, tmp, ((tidx & 0x8) == 0) ? 0x5410 : 0x3276); + tmp = __shfl_xor_sync(0xffffffff, x, 0x10); + x = __byte_perm(x, tmp, ((tidx & 0x10) == 0) ? 0x5410 : 0x3276); + return x; +} +} // namespace detail + +// TODO: arrange elements for one thread together in quantization stage to +// avoid shfl cost + +// frag: (CPY,CPY_N,CPY_K) +template +CUTE_DEVICE void frag_B_layout_swizzle(FragmentB& frag, int tidx) { + // ? not sure if this cast is expensive ? + auto frag_32 = cute::recast<uint32_t>(frag); + CUTE_UNROLL + for (int i = 0; i < size(frag_32); ++i) { + frag_32[i] = detail::frag_B_layout_swizzle_8b(frag_32[i], tidx); + } +} + +// frag: (CPY,CPY_K,CPY_N) +template +CUTE_DEVICE void frag_B_trans_layout_swizzle(FragmentB& frag, int tidx) { + auto frag_32 = cute::recast<uint32_t>(frag); + CUTE_UNROLL + for (int i = 0; i < size(frag_32); ++i) { + frag_32[i] = detail::frag_B_trans_layout_swizzle_8b(frag_32[i], tidx); + } + + auto frag_16 = cute::recast<uint16_t>(frag); + CUTE_UNROLL + for (int i = 0; i < cute::size(frag_16); i += 4) { + // swap 16-bit pair: V0, *V1, *V2, V3 => V0, *V2, *V1, V3 + swap(frag_16(i + 1), frag_16(i + 2)); + } +} + +} // namespace llm \ No newline at end of file diff --git a/src/kernels/attention/tools/attention_traits_viewer.cpp b/src/kernels/attention/tools/attention_traits_viewer.cpp index e28376ca..5d7c01d1 100644 --- a/src/kernels/attention/tools/attention_traits_viewer.cpp +++ b/src/kernels/attention/tools/attention_traits_viewer.cpp @@ -2,6 +2,8 @@ #include #include "../attention_traits_sm80.h" +#include "../cute_extensions.cuh" +#include "cute/numeric/numeric_types.hpp" #include "print_svg.hpp" using namespace cute; @@ -17,7 +19,6 @@ template void print_attn_traits() { // type alias using TiledMma = typename Traits::TiledMma; - using Layout = typename Traits::LayoutConvertor; using SmemLayoutQ = typename Traits::SmemLayoutQ; using SmemLayoutK = typename Traits::SmemLayoutK; @@ -114,9 +115,76 @@ void print_attn_traits() { "smem_layout_o.svg", SmemLayoutO{}, SmemTiledCopyO{}, GmemTiledCopyO{}); } +template +void test_attn_traits() { + // type alias + using DType = typename Traits::DType; + using KV_DType = typename Traits::KV_DType; + using TiledMma = typename Traits::TiledMma; + + using SmemLayoutQ = typename Traits::SmemLayoutQ; + using SmemLayoutK = typename Traits::SmemLayoutK; + using SmemLayoutV = typename Traits::SmemLayoutV; + using SmemLayoutVt = typename Traits::SmemLayoutVt; + using SmemLayoutO = typename Traits::SmemLayoutO; + + using SmemTiledCopyQ =
typename Traits::SmemTiledCopyQ; + using SmemTiledCopyK = typename Traits::SmemTiledCopyK; + using SmemTiledCopyVt = typename Traits::SmemTiledCopyVt; + using SmemTiledCopyO = typename Traits::SmemTiledCopyO; + + // NxK: (64, 64) + Tensor sK = make_tensor(counting_iterator(0), SmemLayoutK{}); + Tensor sVt = make_tensor(counting_iterator(0), SmemLayoutVt{}); + print("sk: "); + print(sK); + print("\n"); + + print("sVt: "); + print(sVt); + print("\n"); + + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_slice(0); + // (MMA, MMA_N, MMA_K) + auto tSrK = make_fragment_B(thr_mma, sK); + print("tSrK: "); + print(tSrK); + print("\n"); + + auto tOrVt = make_fragment_B(thr_mma, sVt); + print("tOrVt: "); + print(tOrVt); + print("\n"); + + SmemTiledCopyK smem_tiled_copy_K; + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(0); + print(smem_thr_copy_K); + print("\n"); + + SmemTiledCopyVt smem_tiled_copy_Vt; + auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_thread_slice(0); + print(smem_thr_copy_Vt); + print("\n"); + + // => ((_8,_1),_4,_4):((_1,_0),_8,_32) + auto tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK); + print("tSrK_copy_view: "); + print(tSrK_copy_view); + print("\n"); + + // => (((_4,_2),_1),_8,_2):(((_1,_32),_0),_4,_64) + auto tOrVt_copy_view = smem_thr_copy_Vt.retile_D(tOrVt); + print("tOrVt_copy_view: "); + print(tOrVt_copy_view); + print("\n"); +} + int main(int argc, char** argv) { // TODO: pass in as parameters - using Element = cute::half_t; + using DTYPE = cute::half_t; + using KV_DTYPE = cute::float_e4m3_t; + // using KV_DTYPE = cute::half_t; constexpr int kHeadDim = 64; constexpr int kBlockM = 64; @@ -124,8 +192,9 @@ int main(int argc, char** argv) { constexpr int kBlockK = 64; using Traits = - AttentionTraitsSM80; + AttentionTraitsSM80; print_attn_traits(); + // test_attn_traits(); return 0; } \ No newline at end of file
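Usage sketch (illustrative only, not part of the patch): once fp8 entries are registered in KV_DTYPE_MAP, a caller selects the KV cache type through the new KV_DTYPE template parameter of run_attention_kernel_sm80<DTYPE, KV_DTYPE, HEAD_DIM, Params>, mirroring the dispatch in the existing test files. The choice of cute::float_e4m3_t below is an assumption, taken from the traits viewer above.

  // fp16/bf16 compute type with an e4m3 fp8 KV cache (assumed instantiation)
  DISPATCH_TORCH_DTYPE(query.dtype(), DTYPE, [&] {
    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
      run_attention_kernel_sm80<DTYPE, cute::float_e4m3_t, HEAD_DIM, Params>(params);
    });
  });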