diff --git a/csrc/cutlass b/csrc/cutlass index 756c351b4..4c42f73fd 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit 756c351b4994854b2f8c6dded3821ebbb580876b +Subproject commit 4c42f73fdab5787e3bb57717f35a8cb1b3c0dc6d diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index a74de974a..94352e5ac 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -1221,7 +1221,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { constexpr int kBlockN = kNThreads / kBlockM; using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; using GmemTiledCopyOaccum = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom, ElementAccum>{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; diff --git a/csrc/flash_attn/src/kernel_traits.h b/csrc/flash_attn/src/kernel_traits.h index 5a7b74911..8c0897488 100644 --- a/csrc/flash_attn/src/kernel_traits.h +++ b/csrc/flash_attn/src/kernel_traits.h @@ -101,8 +101,8 @@ struct Flash_fwd_kernel_traits : public Base { using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); - using SmemCopyAtomO = Copy_Atom; - using SmemCopyAtomOaccum = Copy_Atom; + using SmemCopyAtomO = Copy_Atom, Element>; + using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>; static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); @@ -125,14 +125,14 @@ struct Flash_fwd_kernel_traits : public Base { using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, - DefaultCopy + AutoVectorizingCopyWithAssumedAlignment<128> >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store @@ -144,7 +144,7 @@ struct Flash_fwd_kernel_traits : public Base { Stride< _16, _1>> >; using GmemTiledCopyOaccum = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom, ElementAccum>{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store using GmemLayoutAtomRotcossin = GmemLayoutAtom; @@ -153,7 +153,7 @@ struct Flash_fwd_kernel_traits : public Base { GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 4 vals per load using GmemTiledCopyRotcossinCont = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 8 vals per load }; @@ -250,7 +250,7 @@ struct Flash_bwd_kernel_traits : public Base { composition(SmemLayoutPdS{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); - using SmemCopyAtomPdS = Copy_Atom; + using SmemCopyAtomPdS = Copy_Atom, elem_type>; using SmemLayoutQdOtransposed = decltype( composition(SmemLayoutQdO{}, make_layout(Shape, Int>{}, GenRowMajor{}))); @@ -263,7 +263,7 @@ struct Flash_bwd_kernel_traits : public Base { using SmemLayoutdKV = decltype(tile_to_shape( SmemLayoutAtomdKV{}, make_shape(Int{}, Int{}))); - using SmemCopyAtomdKV = Copy_Atom; + using SmemCopyAtomdKV = Copy_Atom, elem_type>; using SmemLayoutAtomdQ = decltype( composition(Swizzle{}, @@ -272,7 +272,7 @@ struct Flash_bwd_kernel_traits : public Base { using SmemLayoutdQ = decltype(tile_to_shape( SmemLayoutAtomdQ{}, make_shape(Int{}, Int{}))); - using SmemCopyAtomdQ = Copy_Atom; + using SmemCopyAtomdQ = Copy_Atom, elem_type>; // Double buffer for sQ static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element); @@ -303,22 +303,22 @@ struct Flash_bwd_kernel_traits : public Base { using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, - DefaultCopy + AutoVectorizingCopyWithAssumedAlignment<128> >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read using GmemTiledCopydO = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom, elem_type>{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemTiledCopydKV = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom, elem_type>{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemTiledCopydQ = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom, elem_type>{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomdQaccum = std::conditional_t< @@ -329,12 +329,12 @@ struct Flash_bwd_kernel_traits : public Base { Stride< _16, _1>> >; using GmemTiledCopydQaccum = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom, ElementAccum>{}, GmemLayoutAtomdQaccum{}, Layout>{})); // Val layout, 4 vals per store using GmemTiledCopydQaccumAtomicAdd = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom, ElementAccum>{}, Layout, // Thread layout, 8 threads per row Stride<_32, _1>>{}, Layout>{})); // Val layout, 1 val per store