Skip to content

Commit

Permalink
FA3 initial code release
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Jul 11, 2024
1 parent b4a9dd6 commit 7f67966
Show file tree
Hide file tree
Showing 26 changed files with 6,996 additions and 0 deletions.
36 changes: 36 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,42 @@ contains a partial list of places where FlashAttention is being used.
FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE).
Please cite and credit FlashAttention if you use it.


## FlashAttention-3 beta release
FlashAttention-3 is optimized for Hopper GPUs (e.g. H100).

Blogpost: https://tridao.me/blog/2024/flash3/

Paper: https://tridao.me/publications/flash3/flash3.pdf

![FlashAttention-3 speedup on H100 80GB SXM5 with FP16](assets/flash3_fp16_fwd.png)

This is a beta release for testing / benchmarking before we integrate that with
the rest of the repo.

Currently released:
- FP16 forward and backward

Coming soon in the next couple of days / next week:
- BF16
- Variable length (FP16, BF16)
- FP8 forward.

Requirements: H100 / H800 GPU, CUDA >= 12.3.

To install:
```sh
cd hopper
python setup.py install
```
To run the test:
```sh
export PYTHONPATH=$PWD
pytest -q -s test_flash_attn.py
```



## Installation and features

Requirements:
Expand Down
1 change: 1 addition & 0 deletions hopper/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = "3.0.0.b1"
46 changes: 46 additions & 0 deletions hopper/block_info.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/

#pragma once

namespace flash {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<bool Varlen=true>
struct BlockInfo {

template<typename Params>
__device__ BlockInfo(const Params &params, const int bidb)
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
, actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
{
}

template <typename index_t>
__forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
}

template <typename index_t>
__forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
}

const int sum_s_q;
const int sum_s_k;
const int actual_seqlen_q;
// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
const int seqlen_k_cache;
const int actual_seqlen_k;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace flash
214 changes: 214 additions & 0 deletions hopper/epilogue_fwd_sm90_tma.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/

#pragma once

#include <cutlass/cutlass.h>
#include "cute/tensor.hpp"

#include "cutlass/gemm/collective/collective_builder.hpp"

#include "utils.h"

namespace flash {

using namespace cute;

// template <int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename Element_>
template <typename Ktraits>
struct CollectiveEpilogueFwd {

using Element = typename Ktraits::Element;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN;
static constexpr int kHeadDim = Ktraits::kHeadDim;
// using Element = Element_;
// static constexpr int kBlockM = kBlockM_;
// static constexpr int kBlockN = kBlockN_;
// static constexpr int kHeadDim = kHeadDim_;
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;

// static constexpr int kNWarps = kNWarps_;
static constexpr int kNWarps = Ktraits::kNWarps;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
static constexpr bool Is_WS = kNWarps >= 12;

static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
static constexpr int NumMmaThreads = kNThreads - NumCopyThreads;

using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;

// These are for storing the output tensor without TMA (e.g., for setting output to zero)
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
static constexpr int kGmemThreadsPerRow = kHeadDim / kGmemElemsPerLoad;
static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<Shape <Int<NumMmaThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per store

using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));

using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>;
using SharedStorage = cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>>;

using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch)
using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen_q, head, batch)

using TMA_O = decltype(make_tma_copy(
GmemTiledCopyOTMA{},
make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), repeat_like(StrideO{}, int32_t(0)), StrideO{}),
SmemLayoutO{},
select<0, 2>(TileShape_MNK{}),
_1{})); // no mcast for O

// Host side kernel arguments
struct Arguments {
Element* ptr_O;
ShapeO const shape_O;
StrideO const stride_O;
float* ptr_LSE;
StrideLSE const stride_LSE;
};

// Device side kernel params
struct Params {
Element* ptr_O;
ShapeO const shape_O;
StrideO const stride_O;
float* ptr_LSE;
StrideLSE const stride_LSE;
TMA_O tma_store_O;
};

static Params
to_underlying_arguments(Arguments const& args) {
Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O);
TMA_O tma_store_O = make_tma_copy(
GmemTiledCopyOTMA{},
mO,
SmemLayoutO{},
select<0, 2>(TileShape_MNK{}),
_1{}); // no mcast for O
return {args.ptr_O, args.shape_O, args.stride_O, args.ptr_LSE, args.stride_LSE, tma_store_O};
}

/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& epilogue_params) {
cute::prefetch_tma_descriptor(epilogue_params.tma_store_O.get_tma_descriptor());
}

template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
CUTLASS_DEVICE void
store(Params const& epilogue_params,
FrgTensorO const& tOrO,
FrgTensorLSE const& lse,
SharedStorage& shared_storage,
TiledMma tiled_mma,
int thread_idx,
cute::tuple<int32_t, int32_t, int32_t> const& block_coord
) {

auto [m_block, bidh, bidb] = block_coord;
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);

Tensor tOrO_out = flash::convert_type<Element>(tOrO);
Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)

// Make sure all WGs have finished reading V
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0 /*id*/);
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);

Tensor mO = epilogue_params.tma_store_O.get_tma_tensor(epilogue_params.shape_O);
Tensor gO = local_tile(mO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
auto block_tma_O = epilogue_params.tma_store_O.get_slice(_0{});
Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K)
Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K)

auto shape_LSE = select<0, 2, 3>(epilogue_params.shape_O);
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), shape_LSE, epilogue_params.stride_LSE);
Tensor gLSE = local_tile(mLSE(_, bidh, bidb), Shape<Int<kBlockM>>{}, make_coord(m_block));

Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
// taccOcO has shape ((2, 2, V), MMA_M, MMA_K), we only take only the row indices.
Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{});
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
if (get<1>(taccOcO_row(_0{})) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO_row(mi));
if (row < get<0>(shape_LSE) - m_block * kBlockM) { gLSE(row) = lse(mi); }
}
}

if (cutlass::canonical_warp_idx_sync() == kNWarps - 1) {
cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
int const lane_predicate = cute::elect_one_sync();
if (lane_predicate) {
cute::copy(epilogue_params.tma_store_O, tOsO, tOgO);
tma_store_arrive();
}
}
}

CUTLASS_DEVICE void
store_tail() {
tma_store_wait<0>();
}

// Write 0 to output and -inf to LSE
CUTLASS_DEVICE void
store_zero(
Params const& epilogue_params,
int thread_idx,
cute::tuple<int32_t, int32_t, int32_t> const& block_coord
) {
auto [m_block, bidh, bidb] = block_coord;
Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.shape_O, epilogue_params.stride_O);
Tensor gO = local_tile(mO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
auto shape_LSE = select<0, 2, 3>(epilogue_params.shape_O);
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), shape_LSE, epilogue_params.stride_LSE);
Tensor gLSE = local_tile(mLSE(_, bidh, bidb), Shape<Int<kBlockM>>{}, make_coord(m_block));

GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
Tensor tOrO = make_fragment_like(tOgO);
clear(tOrO);
// Construct identity layout for sO
Tensor cO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
#pragma unroll
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.shape_O); }
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.shape_O) - m_block * kBlockM
);
static_assert(kBlockM <= NumMmaThreads);
if (thread_idx < get<0>(shape_LSE) - m_block * kBlockM) { gLSE(thread_idx) = INFINITY; }
}

};

} // namespace flash
Loading

0 comments on commit 7f67966

Please sign in to comment.