diff --git a/hopper/benchmark_attn.py b/hopper/benchmark_attn.py new file mode 100644 index 000000000..306b16239 --- /dev/null +++ b/hopper/benchmark_attn.py @@ -0,0 +1,273 @@ +from functools import partial +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +import time + +try: + import cudnn +except ImportError: + cudnn = None + + +from einops import rearrange, repeat + +# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler +from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler +from flash_attn.flash_attn_interface import flash_attn_func +from flash_attn_interface import flash_attn_func as flash_attn_func_v3, flash_attn_varlen_func as flash_attn_varlen_func_v3 + +# Need to install triton nightly: +# pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly + +try: + from triton_fused_attention import attention as triton_attention +except ImportError: + triton_attention = None + +def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, mode='fwd'): + assert mode in ["fwd", "bwd", "fwd_bwd"] + f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1) + return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f) + + +def convert_to_cudnn_type(torch_type): + if torch_type == torch.float16: + return cudnn.data_type.HALF + elif torch_type == torch.bfloat16: + return cudnn.data_type.BFLOAT16 + elif torch_type == torch.float32: + return cudnn.data_type.FLOAT + elif torch_type == torch.int32: + return cudnn.data_type.INT32 + elif torch_type == torch.int64: + return cudnn.data_type.INT64 + else: + raise ValueError("Unsupported tensor data type.") + + +def cudnn_sdpa_setup(q, k, v, grad, causal=False): + b, nheads, seqlen_q, headdim = q.shape + _, _, seqlen_k, _ = k.shape + assert v.shape == (b, nheads, seqlen_k, headdim) + assert cudnn is not None, 'CUDNN is not available' + q_gpu, k_gpu, v_gpu = q, k, v + o_gpu = torch.empty_like(q_gpu) + stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device) + graph_forward = cudnn.pygraph( + io_data_type=convert_to_cudnn_type(q.dtype), + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + q_forward = graph_forward.tensor_like(q_gpu.detach()) + k_forward = graph_forward.tensor_like(k_gpu.detach()) + v_forward = graph_forward.tensor_like(v_gpu.detach()) + + o_forward, stats_forward = graph_forward.sdpa( + name="sdpa", + q=q_forward, + k=k_forward, + v=v_forward, + is_inference=False, + attn_scale=1.0 / math.sqrt(headdim), + use_causal_mask=causal, + ) + + o_forward.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride()) + stats_forward.set_output(True).set_data_type(cudnn.data_type.FLOAT) + + graph_forward.validate() + graph_forward.build_operation_graph() + graph_forward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph_forward.check_support() + graph_forward.build_plans() + + variant_pack_forward = { + q_forward: q_gpu, + k_forward: k_gpu, + v_forward: v_gpu, + o_forward: o_gpu, + stats_forward: stats_gpu, + } + + dQ_gpu = torch.empty_like(q_gpu) + dK_gpu = torch.empty_like(k_gpu) + dV_gpu = torch.empty_like(v_gpu) + dO_gpu = grad + + graph_backward = cudnn.pygraph( + io_data_type=cudnn.data_type.HALF, + 
intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + + q_backward = graph_backward.tensor_like(q_gpu.detach()) + k_backward = graph_backward.tensor_like(k_gpu.detach()) + v_backward = graph_backward.tensor_like(v_gpu.detach()) + o_backward = graph_backward.tensor_like(o_gpu.detach()) + dO_backward = graph_backward.tensor_like(dO_gpu.detach()) + stats_backward = graph_backward.tensor_like(stats_gpu.detach()) + + dQ_backward, dK_backward, dV_backward = graph_backward.sdpa_backward( + name="sdpa_backward", + q=q_backward, + k=k_backward, + v=v_backward, + o=o_backward, + dO=dO_backward, + stats=stats_backward, + attn_scale=1.0 / math.sqrt(headdim), + use_causal_mask=causal, + ) + + dQ_backward.set_output(True).set_dim(dQ_gpu.size()).set_stride(dQ_gpu.stride()) + dK_backward.set_output(True).set_dim(dK_gpu.size()).set_stride(dK_gpu.stride()) + dV_backward.set_output(True).set_dim(dV_gpu.size()).set_stride(dV_gpu.stride()) + + graph_backward.validate() + graph_backward.build_operation_graph() + graph_backward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph_backward.check_support() + graph_backward.build_plans() + + variant_pack_backward = { + q_backward: q_gpu, + k_backward: k_gpu, + v_backward: v_gpu, + o_backward: o_gpu, + dO_backward: dO_gpu, + stats_backward: stats_gpu, + dQ_backward: dQ_gpu, + dK_backward: dK_gpu, + dV_backward: dV_gpu, + } + + workspace = torch.empty( + max(graph_forward.get_workspace_size(), graph_backward.get_workspace_size()), + device="cuda", dtype=torch.uint8 + ) + + def run_fwd(*args, **kwargs): + graph_forward.execute(variant_pack_forward, workspace) + return o_gpu, stats_gpu + + def run_bwd(*args, **kwargs): + graph_backward.execute(variant_pack_backward, workspace) + return dQ_gpu, dK_gpu, dV_gpu + + return run_fwd, run_bwd + + +torch.manual_seed(0) +repeats = 100 +dropout_p = 0.0 +causal = False +dtype = torch.float16 +device = 'cuda' +verbose = False +batch_size = 2 +# seqlen = 2048 +seqlen = 8192 +# seqlen = 4096 +# seqlen = 2047 +dim = 2048 +# headdim = 128 +# headdim = 64 +headdim = 256 + +# for mode in ['fwd', 'bwd']: +for mode in ['fwd']: + for headdim in [64, 128, 256]: + # for headdim in [128]: + for seqlen in [1024, 2048, 4096, 8192, 16384, 32768]: + # for seqlen in [8192]: + nheads = dim // headdim + # nheads = 24 + # headdim = 64 + # batch_size = 64 + # seqlen = 512 + # nheads = 8 + # headdim = 128 + nheads_kv = nheads + + qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, + requires_grad=True) + q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True) + q_t = q.transpose(1, 2).contiguous().detach().requires_grad_() + k_t = k.transpose(1, 2).contiguous().detach().requires_grad_() + v_t = k.transpose(1, 2).contiguous().detach().requires_grad_() + grad = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype) + grad_t = grad.transpose(1, 2).contiguous() + + bench_fn = benchmark_forward if mode == 'fwd' else partial(benchmark_backward, grad=grad) + + for causal in [False, True]: + # for causal in [True]: + print(f"\n### {headdim = }, {seqlen = }, {causal = } ###") + if headdim <= 128 and cudnn is not None: + cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), 
k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), causal=causal) + f = flops(batch_size, nheads, seqlen, seqlen, headdim, causal=causal, mode=mode) + _, m0 = bench_fn(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2') + if mode == 'bwd': + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=False) + if headdim <= 128: + if triton_attention is not None: + if mode == 'fwd': + time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark + _, m3 = benchmark_forward(triton_attention, q_t, k_t, v_t, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton') + # TODO: fix Triton numeric errors. + # if mode == 'bwd': + # dv, v_t.grad = v_t.grad.clone(), None + # dk, k_t.grad = k_t.grad.clone(), None + # dq, q_t.grad = q_t.grad.clone(), None + # torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05) + # torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05) + # torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05) + if cudnn is not None: + time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark + if mode == 'fwd': + _, m2 = benchmark_forward(cudnn_sdpa_fwd, repeats=repeats, verbose=verbose, desc='CuDNN') + else: + cudnn_sdpa_fwd() + _, m2 = benchmark_forward(cudnn_sdpa_bwd, repeats=repeats, verbose=verbose, desc='CuDNN') + dq, dk, dv = cudnn_sdpa_bwd() + torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05) + torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05) + torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05) + # pytorch_profiler(cudnn_sdpa, backward=False) + if headdim == 128 or mode == 'fwd': + time.sleep(1) + _, m1 = bench_fn(flash_attn_func_v3, q, k, v, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3') + q_var = q.reshape(-1, q.shape[-2], q.shape[-1]) + k_var = k.reshape(-1, k.shape[-2], k.shape[-1]) + v_var = v.reshape(-1, v.shape[-2], v.shape[-1]) + lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32) + cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), torch.cumsum(lens, dim=0, dtype=torch.int32)]).cuda() + time.sleep(1) + _, m1_var = bench_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len') + if mode == 'bwd': + dv, v.grad = v.grad.clone(), None + dk, k.grad = k.grad.clone(), None + dq, q.grad = q.grad.clone(), None + torch.testing.assert_close(ref_dv, dv, atol=0.05, rtol=0.05) + torch.testing.assert_close(ref_dk, dk, atol=0.05, rtol=0.05) + torch.testing.assert_close(ref_dq, dq, atol=0.05, rtol=0.05) + + # pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, backward=False) + print(f'Fav2: {m0.mean * 1e3:.3f}ms, {(f / m0.mean * 1e-12):.1f} TFLOPS') + if headdim <= 128: + if triton_attention is not None: + print(f'Triton: {m3.mean * 1e3:.3f}ms, {(f / m3.mean * 1e-12):.1f} TFLOPS') + if cudnn is not None: + print(f'CuDNN: {m2.mean * 1e3:.3f}ms, {(f / m2.mean * 1e-12):.1f} TFLOPS') + if headdim == 128 or mode == 'fwd': + print(f'Fav3: {m1.mean * 1e3:.3f}ms, {(f / m1.mean * 1e-12):.1f} TFLOPS') + print(f'Fav3 varlen: {m1_var.mean * 1e3:.3f}ms, {(f / m1_var.mean * 1e-12):.1f} TFLOPS') + \ No newline at end of file diff --git 
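Note on the benchmark above: the `flops` helper derives its count from the module-level `seqlen` rather than its own `seqlen_q`/`seqlen_k` arguments, which only matters if the two ever differ. A minimal sketch of the same accounting with the lengths kept distinct (4·b·h·s_q·s_k·d for the two forward matmuls, halved for causal, and the usual 2.5×/3.5× factors for bwd / fwd+bwd); this is illustrative, not part of the benchmark script:

```python
def attn_flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, mode="fwd"):
    # Two matmuls (Q @ K^T and P @ V), 2 FLOPs per multiply-add.
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    f = 4 * batch * nheads * seqlen_q * seqlen_k * headdim // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
```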
a/hopper/epilogue_fwd_sm90_tma.hpp b/hopper/epilogue_fwd_sm90_tma.hpp index 2d5c33eb4..852343860 100644 --- a/hopper/epilogue_fwd_sm90_tma.hpp +++ b/hopper/epilogue_fwd_sm90_tma.hpp @@ -17,20 +17,15 @@ namespace flash { using namespace cute; // template -template +template struct CollectiveEpilogueFwd { using Element = typename Ktraits::Element; static constexpr int kBlockM = Ktraits::kBlockM; static constexpr int kBlockN = Ktraits::kBlockN; static constexpr int kHeadDim = Ktraits::kHeadDim; - // using Element = Element_; - // static constexpr int kBlockM = kBlockM_; - // static constexpr int kBlockN = kBlockN_; - // static constexpr int kHeadDim = kHeadDim_; using TileShape_MNK = Shape, Int, Int>; - // static constexpr int kNWarps = kNWarps_; static constexpr int kNWarps = Ktraits::kNWarps; static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp; static constexpr bool Is_WS = kNWarps >= 12; @@ -38,20 +33,6 @@ struct CollectiveEpilogueFwd { static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup; static constexpr int NumMmaThreads = kNThreads - NumCopyThreads; - using GmemTiledCopyOTMA = cute::SM90_TMA_STORE; - - // These are for storing the output tensor without TMA (e.g., for setting output to zero) - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); - static constexpr int kGmemThreadsPerRow = kHeadDim / kGmemElemsPerLoad; - static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow"); - using GmemLayoutAtom = Layout, Int>, - Stride, _1>>; - using GmemTiledCopyO = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtom{}, - Layout>>{})); // Val layout, 8 or 16 vals per store - using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{}))); @@ -59,52 +40,72 @@ struct CollectiveEpilogueFwd { using SmemCopyAtomO = Copy_Atom; using SharedStorage = cute::array_aligned>; - using ShapeO = cute::Shape; // (seqlen_q, d, head, batch) - using StrideO = cute::Stride; - using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen_q, head, batch) - + using GmemTiledCopyOTMA = cute::SM90_TMA_STORE; using TMA_O = decltype(make_tma_copy( GmemTiledCopyOTMA{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), repeat_like(StrideO{}, int32_t(0)), StrideO{}), + make_tensor( + make_gmem_ptr(static_cast(nullptr)), + typename Seqlen_traits::ShapeT{}, + typename Seqlen_traits::StrideT{} + ), SmemLayoutO{}, select<0, 2>(TileShape_MNK{}), _1{})); // no mcast for O + // These are for storing the output tensor without TMA (e.g., for setting output to zero and var-seq-len) + static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v); + static_assert(kHeadDim % kNumVecElem == 0); + static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem; + static_assert(NumMmaThreads % kNumThreadsPerRow == 0); + static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow; + using TiledCopyOAtom = cute::Copy_Atom, Element>; + using TiledCopyOThrLayout = decltype(cute::make_layout( + cute::make_shape(Int{}, Int{}), + LayoutRight{})); + using TiledCopyOValLayout = decltype(cute::make_layout( + cute::make_shape(_1{}, Int{}), + LayoutRight{})); + using TiledCopyO = decltype(make_tiled_copy( + TiledCopyOAtom{}, + 
TiledCopyOThrLayout{}, // Thr layout + TiledCopyOValLayout{} // Val layout + )); + // Host side kernel arguments struct Arguments { Element* ptr_O; - ShapeO const shape_O; - StrideO const stride_O; + typename Seqlen_traits::LayoutT const layout_O; float* ptr_LSE; - StrideLSE const stride_LSE; + typename Seqlen_traits::LayoutLseT const layout_LSE; }; // Device side kernel params struct Params { Element* ptr_O; - ShapeO const shape_O; - StrideO const stride_O; + typename Seqlen_traits::LayoutT const layout_O; float* ptr_LSE; - StrideLSE const stride_LSE; + typename Seqlen_traits::LayoutLseT const layout_LSE; TMA_O tma_store_O; }; static Params to_underlying_arguments(Arguments const& args) { - Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O); + Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.layout_O); TMA_O tma_store_O = make_tma_copy( GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 2>(TileShape_MNK{}), _1{}); // no mcast for O - return {args.ptr_O, args.shape_O, args.stride_O, args.ptr_LSE, args.stride_LSE, tma_store_O}; + return {args.ptr_O, args.layout_O, args.ptr_LSE, args.layout_LSE, tma_store_O}; } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance CUTLASS_DEVICE static void prefetch_tma_descriptors(Params const& epilogue_params) { - cute::prefetch_tma_descriptor(epilogue_params.tma_store_O.get_tma_descriptor()); + if constexpr (!Seqlen_traits::kUseVarSeqLen) { + cute::prefetch_tma_descriptor(epilogue_params.tma_store_O.get_tma_descriptor()); + } } template @@ -115,7 +116,8 @@ struct CollectiveEpilogueFwd { SharedStorage& shared_storage, TiledMma tiled_mma, int thread_idx, - cute::tuple const& block_coord + cute::tuple const& block_coord, + const Seqlen_traits& seqlen_traits_q ) { auto [m_block, bidh, bidb] = block_coord; @@ -134,16 +136,9 @@ struct CollectiveEpilogueFwd { cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - Tensor mO = epilogue_params.tma_store_O.get_tma_tensor(epilogue_params.shape_O); - Tensor gO = local_tile(mO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) - auto block_tma_O = epilogue_params.tma_store_O.get_slice(_0{}); - Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K) - Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K) - - auto shape_LSE = select<0, 2, 3>(epilogue_params.shape_O); - Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), shape_LSE, epilogue_params.stride_LSE); - Tensor gLSE = local_tile(mLSE(_, bidh, bidb), Shape>{}, make_coord(m_block)); - + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE); + Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor( + mLSE, Shape>{}, bidh, bidb)(_, m_block); Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})); auto thread_mma = tiled_mma.get_thread_slice(thread_idx); Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) @@ -156,19 +151,23 @@ struct CollectiveEpilogueFwd { #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccOcO_row(mi)); - if (row < get<0>(shape_LSE) - m_block * kBlockM) { gLSE(row) = lse(mi); } + if (row < seqlen_traits_q.actual_seq_len - m_block * kBlockM) { gLSE(row) = lse(mi); } } } - if (cutlass::canonical_warp_idx_sync() == kNWarps - 1) { - cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, - 
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - int const lane_predicate = cute::elect_one_sync(); - if (lane_predicate) { - cute::copy(epilogue_params.tma_store_O, tOsO, tOgO); - tma_store_arrive(); - } + int write_warp_idx = kNWarps - 1; + if (cutlass::canonical_warp_idx_sync() == write_warp_idx) { + cutlass::arch::NamedBarrier::sync( + NumMmaThreads + cutlass::NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier + ); } + TiledCopyO gmem_tiled_copy_O; + flash::write_O( + epilogue_params.ptr_O, epilogue_params.tma_store_O, gmem_tiled_copy_O, + epilogue_params.layout_O, select<0, 2>(TileShape_MNK{}), sO, + m_block, bidh, bidb, seqlen_traits_q, write_warp_idx + ); } CUTLASS_DEVICE void @@ -177,20 +176,25 @@ struct CollectiveEpilogueFwd { } // Write 0 to output and -inf to LSE + template CUTLASS_DEVICE void store_zero( - Params const& epilogue_params, - int thread_idx, - cute::tuple const& block_coord - ) { + Params const& epilogue_params, + SharedStorage& shared_storage, + int thread_idx, + cute::tuple const& block_coord, + const Seqlen_traits& seqlen_traits_q + ) { auto [m_block, bidh, bidb] = block_coord; - Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.shape_O, epilogue_params.stride_O); - Tensor gO = local_tile(mO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) - auto shape_LSE = select<0, 2, 3>(epilogue_params.shape_O); - Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), shape_LSE, epilogue_params.stride_LSE); - Tensor gLSE = local_tile(mLSE(_, bidh, bidb), Shape>{}, make_coord(m_block)); - - GmemTiledCopyO gmem_tiled_copy_O; + Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.layout_O); + Tensor gO = seqlen_traits_q.get_local_tile_tensor( + mO, select<0, 2>(TileShape_MNK{}), bidh, bidb + )(_, _, m_block); // (M, K) + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE); + Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor( + mLSE, Shape>{}, bidh, bidb)(_, m_block); + + TiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); Tensor tOgO = gmem_thr_copy_O.partition_D(gO); Tensor tOrO = make_fragment_like(tOgO); @@ -201,13 +205,13 @@ struct CollectiveEpilogueFwd { Tensor tOcO = gmem_thr_copy_O.partition_D(cO); Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.shape_O); } + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.layout_O.shape()); } // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( - gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.shape_O) - m_block * kBlockM + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.layout_O.shape()) - m_block * kBlockM ); static_assert(kBlockM <= NumMmaThreads); - if (thread_idx < get<0>(shape_LSE) - m_block * kBlockM) { gLSE(thread_idx) = INFINITY; } + if (thread_idx < get<0>(epilogue_params.layout_LSE.shape()) - m_block * kBlockM) { gLSE(thread_idx) = INFINITY; } } }; diff --git a/hopper/flash.h b/hopper/flash.h index 7c61f3daa..fa5a8caff 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -57,7 +57,7 @@ struct Flash_fwd_params : public Qkv_params { void * __restrict__ softmax_lseaccum_ptr; // The dimensions. 
- int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q, total_k; // The scaling factors for the kernel. float scale_softmax; @@ -128,6 +128,8 @@ struct Flash_fwd_params : public Qkv_params { void * __restrict__ alibi_slopes_ptr; index_t alibi_slopes_batch_stride; + bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q]. + int * __restrict__ tile_count_semaphore; }; diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index f21d2d12e..397ed4cc3 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -43,7 +43,8 @@ void set_params_fprop(Flash_fwd_params ¶ms, float softmax_scale, int window_size_left, int window_size_right, - bool seqlenq_ngroups_swapped=false) { + bool seqlenq_ngroups_swapped=false, + bool unpadded_lse=false) { // Reset the parameters params = {}; @@ -81,6 +82,11 @@ void set_params_fprop(Flash_fwd_params ¶ms, params.cu_seqlens_k = static_cast(cu_seqlens_k_d); params.seqused_k = static_cast(seqused_k); + TORCH_CHECK( + bool(params.cu_seqlens_q) == bool(params.cu_seqlens_k), + "cu_seqlens_q and cu_seqlens_k must be both null or non-null" + ); + // P = softmax(QK^T) params.p_ptr = p_d; @@ -139,6 +145,8 @@ void set_params_fprop(Flash_fwd_params ¶ms, #ifdef FLASHATTENTION_DISABLE_UNEVEN_K TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32."); #endif + + params.unpadded_lse = unpadded_lse; } void set_params_dgrad(Flash_bwd_params ¶ms, @@ -372,6 +380,154 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p}; } +std::vector +mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + c10::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. 
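Note: `mha_varlen_fwd` takes Q/K/V packed along the token dimension (`total_q x num_heads x head_size`) together with the `b+1` cumulative offsets in `cu_seqlens_q`/`cu_seqlens_k`. A hedged sketch of packing a padded batch into that format; `pack_for_varlen` and `lengths` are illustrative names, not part of the extension API:

```python
import torch
import torch.nn.functional as F

def pack_for_varlen(x, lengths):
    """Pack a padded (batch, max_seqlen, nheads, headdim) tensor into
    (total_tokens, nheads, headdim) plus int32 cumulative sequence lengths."""
    packed = torch.cat([x[b, : int(l)] for b, l in enumerate(lengths)], dim=0)
    # cu_seqlens[0] = 0, cu_seqlens[i + 1] = lengths[0] + ... + lengths[i]
    cu_seqlens = F.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0))
    return packed, cu_seqlens.to(x.device)
```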
+ int max_seqlen_q, + const int max_seqlen_k, + const float softmax_scale, + bool is_causal) { + + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer."); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(cu_seqlens_q); + CHECK_DEVICE(cu_seqlens_k); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); + + const auto sizes = q.sizes(); + + const int batch_size = cu_seqlens_q.numel() - 1; + int num_heads = sizes[1]; + const int head_size_og = sizes[2]; + const int num_heads_k = k.size(1); + + int window_size_left = -1; + int window_size_right = -1; + if (is_causal) { window_size_right = 0; } + + void *cu_seqlens_q_d = cu_seqlens_q.data_ptr(); + + const int total_q = q.sizes()[0]; + + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (window_size_left >= max_seqlen_k) { window_size_left = -1; } + if (window_size_right >= max_seqlen_k) { window_size_right = -1; } + + CHECK_SHAPE(q, total_q, num_heads, head_size_og); + const int total_k = k.size(0); + CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); + + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + if (seqused_k.has_value()){ + auto seqused_k_ = seqused_k.value(); + TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32"); + TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device"); + TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous"); + CHECK_SHAPE(seqused_k_, batch_size); + } + + at::Tensor q_padded, k_padded, v_padded; + if (head_size_og % 8 != 0) { + q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + q_padded = q; + k_padded = k; + v_padded = v; + } + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og); + if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } + } else { + out = 
torch::empty_like(q_padded); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, 8); + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + max_seqlen_q, max_seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q_padded, k_padded, v_padded, out, + cu_seqlens_q_d, + cu_seqlens_k.data_ptr(), + seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, + /*p_d=*/nullptr, + softmax_lse.data_ptr(), + /*p_dropout=*/0.f, + softmax_scale, + window_size_left, + window_size_right, + /*seqlenq_ngroups_swapped=*/false, + /*unpadded_lse=*/true); + params.total_q = total_q; + params.total_k = total_k; + + if (max_seqlen_k > 0) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd(params, stream); + } else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + at::Tensor out_padded = out; + if (head_size_og % 8 != 0) { + out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + if (out_.has_value()) { out_.value().copy_(out); } + } + + return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse}; +} + void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { // FP16_SWITCH(!params.is_bf16, [&] { // HEADDIM_SWITCH(params.d, [&] { @@ -577,4 +733,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "FlashAttention"; m.def("fwd", &mha_fwd, "Forward pass"); m.def("bwd", &mha_bwd, "Backward pass"); + m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)"); } diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index c09342826..d88ab78ea 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -57,6 +57,83 @@ def _flash_attn_backward( ) return dq, dk, dv, softmax_d +def _flash_attn_varlen_forward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + causal, +): + maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + out, q, k, v, out_padded, softmax_lse = flashattn_hopper_cuda.varlen_fwd( + q, + k, + v, + None, + cu_seqlens_q, + cu_seqlens_k, + None, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + causal, + ) + # if out.isnan().any() or softmax_lse.isnan().any(): + # breakpoint() + return out, q, k, v, out_padded, softmax_lse + + +def _flash_attn_varlen_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + causal, +): + maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + # dq, dk, dv are allocated by us so they should already be contiguous + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + ( + dq, + dk, + dv, + softmax_d, + ) = _get_fa_module().varlen_bwd( + dout, + q, + k, 
+ v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + causal, + ) + # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): + # breakpoint() + return dq, dk, dv, softmax_d + class FlashAttnFunc(torch.autograd.Function): @staticmethod @@ -105,6 +182,71 @@ def backward(ctx, dout, *args): return dq, dk, dv, None, None +class FlashAttnVarlenFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + causal, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + causal=causal, + ) + ctx.save_for_backward( + q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k + ) + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out, softmax_lse + + @staticmethod + def backward(ctx, dout, *args): + # TODO: Uncomment these when var-seq-len is supported in bwd kernel. + # q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + # dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + # _flash_attn_varlen_backward( + # dout, + # q, + # k, + # v, + # out, + # softmax_lse, + # dq, + # dk, + # dv, + # cu_seqlens_q, + # cu_seqlens_k, + # ctx.max_seqlen_q, + # ctx.max_seqlen_k, + # ctx.softmax_scale, + # ctx.causal, + # ) + # dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension + # dk = dk[..., : dout.shape[-1]] + # dv = dv[..., : dout.shape[-1]] + # return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None + return None, None, None, None, None, None, None, None, None, None, None, None, None, None, None + + def flash_attn_func( q, k, @@ -167,3 +309,62 @@ def flash_attn_func( softmax_scale, causal, ) + + +def flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale=None, + causal=False, +): + """ + Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + Arguments: + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. 
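Note: continuing the hypothetical `pack_for_varlen` helper sketched earlier, this is what driving the varlen path end to end could look like; on this path the returned `softmax_lse` is in the unpadded `(nheads, total_q)` layout, so per-sequence results are recovered with the same `cu_seqlens_q` offsets:

```python
# q_packed: (total_q, nheads, headdim); k_packed, v_packed: (total_k, nheads_kv, headdim)
out, softmax_lse = flash_attn_varlen_func(
    q_packed, k_packed, v_packed,
    cu_seqlens_q, cu_seqlens_k,
    int(seqlens_q.max()),  # max_seqlen_q
    int(seqlens_k.max()),  # max_seqlen_k
    causal=True,
)
i = 0  # example: first sequence in the batch
out_i = out[cu_seqlens_q[i] : cu_seqlens_q[i + 1]]             # (seqlen_i, nheads, headdim)
lse_i = softmax_lse[:, cu_seqlens_q[i] : cu_seqlens_q[i + 1]]  # (nheads, seqlen_i)
```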
+ max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + """ + return FlashAttnVarlenFunc.apply( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + causal, + ) diff --git a/hopper/flash_fwd_kernel.h b/hopper/flash_fwd_kernel.h index b97250d65..8bcfab6dc 100644 --- a/hopper/flash_fwd_kernel.h +++ b/hopper/flash_fwd_kernel.h @@ -24,11 +24,12 @@ namespace flash { using namespace cute; -template +template __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) - compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd::Params const mainloop_params, - CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd::Params const epilogue_params, - CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params + compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd::Params const mainloop_params, + CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd::Params const epilogue_params, + CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params, + Seqlen_traits seqlen_traits_q, Seqlen_traits seqlen_traits_k ) { using Element = typename Ktraits::Element; @@ -46,8 +47,8 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, // static constexpr int kBlockN = Ktraits::kBlockN; // constexpr int kHeadDim = Ktraits::kHeadDim; - using CollectiveMainloop = CollectiveMainloopFwd; - using CollectiveEpilogue = CollectiveEpilogueFwd; + using CollectiveMainloop = CollectiveMainloopFwd; + using CollectiveEpilogue = CollectiveEpilogueFwd; using MainloopPipeline = typename Ktraits::MainloopPipeline; using PipelineParams = typename MainloopPipeline::Params; @@ -115,14 +116,21 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, auto block_coord = work_tile_info.get_block_coord(scheduler_params); auto [m_block, bidh, bidb] = block_coord; - int n_block_max = collective_mainloop.get_n_block_max(mainloop_params, m_block); + seqlen_traits_q.init(bidb); + seqlen_traits_k.init(bidb); + if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) { + continue; + } + int n_block_max = collective_mainloop.get_n_block_max( + mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); if (Is_causal && n_block_max <= 0) { scheduler.prefetch_next_work(scheduler_params, work_tile_info); scheduler.broadcast_next_work(work_tile_info); continue; } collective_mainloop.load(mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v, - shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx); + shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx, + seqlen_traits_q, seqlen_traits_k); ++work_idx; } collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v); @@ -154,17 +162,24 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, auto block_coord = work_tile_info.get_block_coord(scheduler_params); auto [m_block, bidh, bidb] = block_coord; - int n_block_max = 
collective_mainloop.get_n_block_max(mainloop_params, m_block); + seqlen_traits_q.init(bidb); + seqlen_traits_k.init(bidb); + if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) { + continue; + } + int n_block_max = collective_mainloop.get_n_block_max( + mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); if (Is_causal && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE. - collective_epilogue.store_zero(epilogue_params, threadIdx.x - NumCopyThreads, block_coord); + collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q); continue; } collective_mainloop.mma(mainloop_params, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v, - tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block, shared_storage); + tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block, shared_storage, + seqlen_traits_q, seqlen_traits_k); // tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads + (work_idx >> 30), work_idx, shared_storage); collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1, - threadIdx.x - NumCopyThreads, block_coord); + threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q); ++work_idx; } diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index c58c21dea..cd7adb3bf 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -14,41 +14,61 @@ #include "tile_scheduler.hpp" #include "flash_fwd_kernel.h" #include "kernel_traits.h" +#include "seq_len.h" #include "utils.h" -template +template void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using Element = typename Kernel_traits::Element; using TileShape_MNK = typename Kernel_traits::TileShape_MNK; using ClusterShape = typename Kernel_traits::ClusterShape_MNK; // print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{}); - using CollectiveMainloop = flash::CollectiveMainloopFwd; - using CollectiveEpilogue = flash::CollectiveEpilogueFwd; - using Scheduler = std::conditional_t>; - // flash::SingleTileScheduler>; + using CollectiveMainloop = flash::CollectiveMainloopFwd; + using CollectiveEpilogue = flash::CollectiveEpilogueFwd; + using Scheduler = std::conditional_t< + Seqlen_traits::kUseVarSeqLen, + flash::SingleTileScheduler, + std::conditional_t + >>; + // using Scheduler = flash::SingleTileScheduler; + Seqlen_traits seqlen_traits_q( + params.total_q, params.seqlen_q, params.cu_seqlens_q); + Seqlen_traits seqlen_traits_k( + params.total_k, params.seqlen_k, params.cu_seqlens_k, params.seqused_k); typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments({ static_cast(params.q_ptr), - {params.seqlen_q, params.d, params.h, params.b}, // shape_Q - {params.q_row_stride, _1{}, params.q_head_stride, params.q_batch_stride}, // stride_Q + seqlen_traits_q.get_gmem_layout( + params.seqlen_q, params.d, params.h, params.b, + params.q_row_stride, params.q_head_stride, params.q_batch_stride + ), // layout_Q static_cast(params.k_ptr), - {params.seqlen_k, params.d, params.h_k, params.b}, // shape_K - {params.k_row_stride, _1{}, params.k_head_stride, params.k_batch_stride}, // stride_K + seqlen_traits_k.get_gmem_layout( + params.seqlen_k, params.d, params.h_k, params.b, + params.k_row_stride, params.k_head_stride, params.k_batch_stride + ), // layout_K static_cast(params.v_ptr), - {params.v_row_stride, 
_1{}, params.v_head_stride, params.v_batch_stride}, // stride_V + seqlen_traits_k.get_gmem_layout( + params.seqlen_k, params.d, params.h_k, params.b, + params.v_row_stride, params.v_head_stride, params.v_batch_stride + ), // layout_V params.scale_softmax_log2 }); typename CollectiveEpilogue::Params epilogue_params = CollectiveEpilogue::to_underlying_arguments({ static_cast(params.o_ptr), - {params.seqlen_q, params.d, params.h, params.b}, // shape_O - {params.o_row_stride, _1{}, params.o_head_stride, params.o_batch_stride}, // stride_O + seqlen_traits_q.get_gmem_layout( + params.seqlen_q, params.d, params.h, params.b, + params.o_row_stride, params.o_head_stride, params.o_batch_stride + ), // layout_O static_cast(params.softmax_lse_ptr), - {_1{}, params.seqlen_q, params.h * params.seqlen_q}, // stride_LSE + seqlen_traits_q.get_lse_gmem_layout( + params.seqlen_q, params.h, params.b + ) // layout_LSE }); int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM); @@ -58,7 +78,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // Get the ptr to kernel function. void *kernel; - kernel = (void *)flash::compute_attn_ws; + kernel = (void *)flash::compute_attn_ws; int smem_size = sizeof(typename Kernel_traits::SharedStorage); // int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q)); // int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k)); @@ -81,7 +101,9 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { dim3 block_dims(ctaSize); dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{})); cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream}; - cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, epilogue_params, scheduler_params); + cutlass::launch_kernel_on_cluster( + launch_params, kernel, mainloop_params, epilogue_params, + scheduler_params, seqlen_traits_q, seqlen_traits_k); CHECK_CUDA_KERNEL_LAUNCH(); } @@ -89,7 +111,12 @@ template void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 64; BOOL_SWITCH(params.is_causal, Is_causal, [&] { - run_flash_fwd, Is_causal>(params, stream); + SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits, + Is_causal, Seqlen_traits + >(params, stream); + }); }); } @@ -97,9 +124,14 @@ template void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 128; BOOL_SWITCH(params.is_causal, Is_causal, [&] { - // Only use Cluster if number of tiles along seqlen_q is even - BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0, UseCluster, [&] { - run_flash_fwd, Is_causal>(params, stream); + SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { + // Only use Cluster if number of tiles along seqlen_q is even and not Is_causal + BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits, + Is_causal, Seqlen_traits + >(params, stream); + }); }); }); } @@ -108,9 +140,14 @@ template void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 256; BOOL_SWITCH(params.is_causal, Is_causal, [&] { - // Only use Cluster if number of tiles along seqlen_q is even - BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0, UseCluster, [&] { - run_flash_fwd, Is_causal>(params, 
stream); + SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] { + // Only use Cluster if number of tiles along seqlen_q is even + BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] { + run_flash_fwd< + Flash_fwd_kernel_traits, + Is_causal, Seqlen_traits + >(params, stream); + }); }); }); } diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index f9dc94a23..2de15fb9c 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -21,7 +21,7 @@ namespace flash { using namespace cute; -template +template struct CollectiveMainloopFwd { using Element = typename Ktraits::Element; @@ -64,19 +64,24 @@ struct CollectiveMainloopFwd { // decltype(tile_to_shape(SmemLayoutAtomVTMA{}, // make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); - using ShapeQKV = cute::Shape; // (seqlen, d, head, batch) - using StrideQKV = cute::Stride; - using TMA_Q = decltype(make_tma_copy( GmemTiledCopyQ{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), repeat_like(StrideQKV{}, int32_t(0)), StrideQKV{}), + make_tensor( + make_gmem_ptr(static_cast(nullptr)), + repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)), + typename Seqlen_traits::StrideT{} + ), SmemLayoutQ{}, select<0, 2>(TileShape_MNK{}), _1{})); // no mcast for Q using TMA_KV = decltype(make_tma_copy( GmemTiledCopyKV{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), repeat_like(StrideQKV{}, int32_t(0)), StrideQKV{}), + make_tensor( + make_gmem_ptr(static_cast(nullptr)), + repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)), + typename Seqlen_traits::StrideT{} + ), take<0, 2>(SmemLayoutK{}), select<1, 2>(TileShape_MNK{}), size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any @@ -95,20 +100,19 @@ struct CollectiveMainloopFwd { // Host side kernel arguments struct Arguments { Element const* ptr_Q; - ShapeQKV const shape_Q; - StrideQKV const stride_Q; + typename Seqlen_traits::LayoutT layout_Q; Element const* ptr_K; - ShapeQKV const shape_K; - StrideQKV const stride_K; + typename Seqlen_traits::LayoutT layout_K; Element const* ptr_V; - StrideQKV const stride_V; + typename Seqlen_traits::LayoutT layout_V; float const softmax_scale_log2; }; // Device side kernel params struct Params { - ShapeQKV const shape_Q; - ShapeQKV const shape_K; + typename Seqlen_traits::LayoutT layout_Q; + typename Seqlen_traits::LayoutT layout_K; + typename Seqlen_traits::LayoutT layout_V; cutlass::FastDivmod qhead_per_khead_divmod; TMA_Q tma_load_Q; TMA_KV tma_load_K, tma_load_V; @@ -118,29 +122,29 @@ struct CollectiveMainloopFwd { static Params to_underlying_arguments(Arguments const& args) { - Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q); + Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q); TMA_Q tma_load_Q = make_tma_copy( GmemTiledCopyQ{}, mQ, SmemLayoutQ{}, select<0, 2>(TileShape_MNK{}), _1{}); // no mcast for Q - Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K); + Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K); TMA_KV tma_load_K = make_tma_copy( GmemTiledCopyKV{}, mK, SmemLayoutK{}(_, _, _0{}), select<1, 2>(TileShape_MNK{}), size<0>(ClusterShape{})); // mcast along M mode for this N load, if any - Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_K, args.stride_V); + Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V); TMA_KV tma_load_V = 
make_tma_copy( GmemTiledCopyKV{}, mV, SmemLayoutV{}(_, _, _0{}), select<1, 2>(TileShape_MNK{}), size<0>(ClusterShape{})); // mcast along M mode for this N load, if any - return {args.shape_Q, args.shape_K, - cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), + return {args.layout_Q, args.layout_K, args.layout_V, + cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))), tma_load_Q, tma_load_K, tma_load_V, args.softmax_scale_log2}; } @@ -154,11 +158,15 @@ struct CollectiveMainloopFwd { } CUTLASS_DEVICE - int get_n_block_max(Params const& mainloop_params, int m_block) { + int get_n_block_max( + Params const& mainloop_params, int m_block, + const Seqlen_traits& seqlen_traits_q, + const Seqlen_traits& seqlen_traits_k + ) { static constexpr int kBlockM = get<0>(TileShape_MNK{}); static constexpr int kBlockN = get<1>(TileShape_MNK{}); - int const seqlen_q = get<0>(mainloop_params.shape_Q); - int const seqlen_k = get<0>(mainloop_params.shape_K); + int const seqlen_q = seqlen_traits_q.actual_seq_len; + int const seqlen_k = seqlen_traits_k.actual_seq_len; int n_block_max = cute::ceil_div(seqlen_k, kBlockN); if constexpr (Is_causal) { n_block_max = std::min(n_block_max, @@ -179,16 +187,18 @@ struct CollectiveMainloopFwd { typename Scheduler::Params const& scheduler_params, typename Scheduler::WorkTileInfo& work_tile_info, cute::tuple block_coord, - int work_idx + int work_idx, + const Seqlen_traits& seqlen_traits_q, + const Seqlen_traits& seqlen_traits_k ) { Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{}); - Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.shape_Q); - Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.shape_K); - Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.shape_K); + Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape()); + Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape()); + Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape()); auto [m_block, bidh, bidb] = block_coord; int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh); @@ -197,9 +207,12 @@ struct CollectiveMainloopFwd { uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) - Tensor gK = local_tile(mK(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gV = local_tile(mV(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gQ = seqlen_traits_q.get_local_tile_tensor( + mQ, select<0, 2>(TileShape_MNK{}), bidh, bidb)(_, _, m_block); // (M, K) + Tensor gK = seqlen_traits_k.get_local_tile_tensor( + mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _) + Tensor gV = seqlen_traits_k.get_local_tile_tensor( + mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb); // (N, K, _) Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{})); Tensor gQ_x = make_tensor(gQ.data(), 
make_layout(gQ.layout(), Layout<_1>{})); @@ -218,7 +231,7 @@ struct CollectiveMainloopFwd { } } - int n_block_max = get_n_block_max(mainloop_params, m_block); + int n_block_max = get_n_block_max(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k); int n_block = n_block_max - 1; int lane_predicate = cute::elect_one_sync(); @@ -331,7 +344,9 @@ struct CollectiveMainloopFwd { int thread_idx, int work_idx, int m_block, - SharedStorage& shared_storage + SharedStorage& shared_storage, + const Seqlen_traits& seqlen_traits_q, + const Seqlen_traits& seqlen_traits_k ) { static_assert(is_rmem::value, "O tensor must be rmem resident."); @@ -360,8 +375,8 @@ struct CollectiveMainloopFwd { }; tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero; - int const seqlen_q = get<0>(mainloop_params.shape_Q); - int const seqlen_k = get<0>(mainloop_params.shape_K); + int const seqlen_q = seqlen_traits_q.actual_seq_len; + int const seqlen_k = seqlen_traits_k.actual_seq_len; int n_block = n_block_count - 1; cutlass::ConsumerToken barrier_token = static_cast(shared_storage.barrier_Q.try_wait(work_idx % 2)); @@ -483,4 +498,3 @@ struct CollectiveMainloopFwd { }; } // namespace flash - diff --git a/hopper/seq_len.h b/hopper/seq_len.h new file mode 100644 index 000000000..76c4d08a3 --- /dev/null +++ b/hopper/seq_len.h @@ -0,0 +1,168 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include + +namespace flash { + +static constexpr int kMaxTileSize = 128; + +template class SeqLenTraits { +public: + // Total number of queries / keys. Unpadded. + int sum_s = 0; + // seq len offsets. + int *cu_seq_len = nullptr; + // actual seq len array. + int *seq_used = nullptr; + // seq len of the current batch. + int actual_seq_len = -1; + + // Whether this is for fixed-seq-len or var-seq-len. + static constexpr bool kUseVarSeqLen = UseVarSeqLen; + + using ShapeT = std::conditional_t< + UseVarSeqLen, + cute::Shape, + cute::Shape + >; + using StrideT = std::conditional_t< + UseVarSeqLen, + cute::Shape, + cute::Shape + >; + using LayoutT = cute::Layout; + + using ShapeLseT = std::conditional_t< + UseVarSeqLen, + cute::Shape, + cute::Shape + >; + using StrideLseT = std::conditional_t< + UseVarSeqLen, + cute::Shape, + cute::Shape + >; + using LayoutLseT = cute::Layout; + + CUTLASS_HOST SeqLenTraits() {} + + CUTLASS_HOST SeqLenTraits( + int sum_s, int max_seq_len, int *cu_seq_len = nullptr, int *seq_used = nullptr): + sum_s(sum_s), cu_seq_len(cu_seq_len), seq_used(seq_used), actual_seq_len(max_seq_len) {} + + // Returns the layout of a tensor in MKHB format in global memory. + // padded: only useful for var-seq-len for dq_accum and softmax_d. + CUTLASS_HOST_DEVICE auto get_gmem_layout( + int m, int k, int h, int b, + int64_t m_stride, int64_t h_stride, int64_t b_stride, + bool padded = false) const { + static_assert(!UseVarSeqLen, "Default implementation is for FixedSeqLen."); + return make_layout(make_shape(m, k, h, b), + make_stride(m_stride, cute::_1{}, h_stride, b_stride)); + } + + // Returns the layout of a tensor in MKHB format in global memory. + // padded: only useful for var-seq-len for dq_accum and softmax_d. 
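Note: the variable-seq-len specializations further down replace the batched `(seqlen, d, head, batch)` layout with a packed `(total_tokens, d, head)` layout and use `cu_seq_len` to locate each batch element. A small Python sketch of the index arithmetic those specializations implement (the function name and `kBlockM` default are illustrative stand-ins):

```python
def q_tile_rows(cu_seq_len, bidb, m_block, kBlockM=128):
    """Row range of tile (m_block, bidb) inside the packed (total_q, d, h) tensor."""
    start = cu_seq_len[bidb]
    # What VarSeqLenTraits::init(bidb) caches as actual_seq_len
    # (seq_used[bidb] would take precedence if provided).
    actual_seq_len = cu_seq_len[bidb + 1] - cu_seq_len[bidb]
    lo = start + m_block * kBlockM
    hi = start + min((m_block + 1) * kBlockM, actual_seq_len)  # last tile may be partial
    return lo, hi
```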
+ CUTLASS_HOST_DEVICE auto get_lse_gmem_layout( + int m, int h, int b, bool padded = false) const { + static_assert(!UseVarSeqLen, "Default implementation is for FixedSeqLen."); + return make_layout(make_shape(b, h, m), + make_stride(int64_t(h * m), int64_t(m), cute::_1())); + } + + CUTLASS_DEVICE void init(int bidb) {} + + template + CUTLASS_DEVICE auto get_local_tile_tensor( + const MTensor &m_tensor, const Shape &tile_shape, + int bidh, int bidb, bool padded = false) const { + auto g_tensor = local_tile( + m_tensor(_, _, bidh, bidb), tile_shape, make_coord(_, _0{})); + return g_tensor; + } + + template + CUTLASS_DEVICE auto get_lse_local_tile_tensor( + const MTensor &m_tensor, const Shape &tile_shape, + int bidh, int bidb, bool padded = false) const { + auto g_tensor = local_tile(m_tensor(bidb, bidh, _), tile_shape, make_coord(_)); + return g_tensor; + } +}; + +using FixedSeqLenTraits = SeqLenTraits; + +using VarSeqLenTraits = SeqLenTraits; + +// Returns the static layout of a var-seq-len tensor in global memory based on +// max_seq_len and max_batch_size. +// padded: only useful for var-seq-len for dq_accum and softmax_d. +// When padded is True, use B_M + kMaxTileSize * B as the total B_M. +template <> +CUTLASS_HOST_DEVICE auto VarSeqLenTraits::get_gmem_layout( + int m, int k, int h, int b, + int64_t m_stride, int64_t h_stride, int64_t b_stride, + bool padded) const { + return make_layout( + make_shape(sum_s + (padded ? kMaxTileSize * b : 0), k, h), + make_stride(m_stride, cute::_1{}, h_stride)); +} + +// padded: only useful for var-seq-len for dq_accum and softmax_d. +// When padded is True, use B_M + kMaxTileSize * B as the total B_M. +template <> +CUTLASS_HOST_DEVICE auto VarSeqLenTraits::get_lse_gmem_layout( + int m, int h, int b, bool padded) const { + return make_layout( + make_shape(h, sum_s + (padded ? kMaxTileSize * b : 0)), + make_stride(int64_t(sum_s + (padded ? kMaxTileSize * b : 0)), cute::_1())); +} + +template <> +CUTLASS_DEVICE void VarSeqLenTraits::init(int bidb) { + actual_seq_len = + seq_used ? seq_used[bidb] : (cu_seq_len[bidb + 1] - cu_seq_len[bidb]); +} + +template <> +template +CUTLASS_DEVICE auto VarSeqLenTraits::get_local_tile_tensor( + const MTensor &m_tensor, const Shape &tile_shape, + int bidh, int bidb, bool padded) const { + auto g_offset = local_tile( + m_tensor(_, _, bidh), + cute::make_shape(1, get<1>(tile_shape)), + make_coord(cu_seq_len[bidb] + (padded ? kMaxTileSize * bidb : 0), _0{})); + auto g_sequence = make_tensor( + g_offset.data(), + make_layout( + cute::make_shape(actual_seq_len, get<1>(tile_shape)), + g_offset.stride() + )); + auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{})); + return g_tensor; +} + +template <> +template +CUTLASS_DEVICE auto VarSeqLenTraits::get_lse_local_tile_tensor( + const MTensor &m_tensor, const Shape &tile_shape, + int bidh, int bidb, bool padded) const { + auto g_offset = local_tile( + m_tensor(bidh, _), cute::make_shape(_1{}), + make_coord(cu_seq_len[bidb] + (padded ? 
+      make_coord(cu_seq_len[bidb] + (padded ? kMaxTileSize * bidb : 0), _0{}));
+  auto g_sequence = make_tensor(
+      g_offset.data(),
+      make_layout(
+          cute::make_shape(actual_seq_len, get<1>(tile_shape)),
+          g_offset.stride()
+      ));
+  auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
+  return g_tensor;
+}
+
+template <>
+template <typename MTensor, typename Shape>
+CUTLASS_DEVICE auto VarSeqLenTraits::get_lse_local_tile_tensor(
+    const MTensor &m_tensor, const Shape &tile_shape,
+    int bidh, int bidb, bool padded) const {
+  auto g_offset = local_tile(
+      m_tensor(bidh, _), cute::make_shape(_1{}),
+      make_coord(cu_seq_len[bidb] + (padded ? kMaxTileSize * bidb : 0)));
+  auto g_sequence = make_tensor(
+      g_offset.data(),
+      make_layout(cute::make_shape(actual_seq_len), cute::make_shape(_1{})));
+  auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_));
+  return g_tensor;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace flash
diff --git a/hopper/static_switch.h b/hopper/static_switch.h
index e870643e7..d9ec62224 100644
--- a/hopper/static_switch.h
+++ b/hopper/static_switch.h
@@ -66,18 +66,14 @@
     }                                                                       \
   }()
 
-#define SEQLEN_SWITCH(USE_VAR_SEQ_LEN, SEQ_LEN_OUT_OF_BOUND_CHECK, ...)     \
+#define SEQLEN_SWITCH(USE_VAR_SEQ_LEN, NAME, ...)                           \
   [&] {                                                                     \
-    if (!USE_VAR_SEQ_LEN) {                                                 \
-      if (SEQ_LEN_OUT_OF_BOUND_CHECK) {                                     \
-        using kSeqLenTraitsType = FixedSeqLenTraits;                        \
-        return __VA_ARGS__();                                               \
-      } else {                                                              \
-        using kSeqLenTraitsType = FixedSeqLenTraits;                        \
-        return __VA_ARGS__();                                               \
-      }                                                                     \
+    bool useSeqLen = USE_VAR_SEQ_LEN;                                       \
+    if (useSeqLen) {                                                        \
+      using NAME = flash::VarSeqLenTraits;                                  \
+      return __VA_ARGS__();                                                 \
     } else {                                                                \
-      using kSeqLenTraitsType = VarSeqLenTraits;                            \
+      using NAME = flash::FixedSeqLenTraits;                                \
       return __VA_ARGS__();                                                 \
     }                                                                       \
   }()
diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py
index 97852d47e..55ec48686 100644
--- a/hopper/test_flash_attn.py
+++ b/hopper/test_flash_attn.py
@@ -5,40 +5,12 @@ import torch.nn.functional as F
 
 from einops import rearrange, repeat
 
-from flash_attn_interface import flash_attn_func
+from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
+from tests.test_util import generate_random_padding_mask, generate_qkv, construct_local_mask, attention_ref
 
 ABS_TOL = 5e-3
 REL_TOL = 1e-1
 
-def construct_local_mask(
-    seqlen_q,
-    seqlen_k,
-    window_size=(-1, -1),  # -1 means infinite window size
-    query_padding_mask=None,
-    key_padding_mask=None,
-    device=None,
-):
-    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
-    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
-    sk = (
-        seqlen_k
-        if key_padding_mask is None
-        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
-    )
-    sq = (
-        seqlen_q
-        if query_padding_mask is None
-        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
-    )
-    if window_size[0] < 0:
-        return col_idx > row_idx + sk - sq + window_size[1]
-    else:
-        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
-        return torch.logical_or(
-            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
-            col_idx < row_idx + sk - sq - window_size[0],
-        )
-
 def print_diffs(out, out_ref):
     out_1d = out.flatten()
     out_ref_1d = out_ref.flatten()
@@ -51,86 +23,6 @@ def print_diffs(out, out_ref):
         print(f"==== diff ==== {idx}, test: {e_o}, ref: {e_o_ref}")
 
 
-def attention_ref(
-    q,
-    k,
-    v,
-    query_padding_mask=None,
-    key_padding_mask=None,
-    attn_bias=None,
-    dropout_p=0.0,
-    dropout_mask=None,
-    causal=False,
-    upcast=True,
-    reorder_ops=False,
-):
-    """
-    Arguments:
-        q: (batch_size, seqlen_q, nheads, head_dim)
-        k: (batch_size, seqlen_k, nheads, head_dim)
-        v: (batch_size, seqlen_k, nheads, head_dim)
-        query_padding_mask: (batch_size, seqlen_q)
-        key_padding_mask: (batch_size, seqlen_k)
-        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
-        dropout_p: float
-        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
-        causal: whether to apply causal masking
-        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
-            output back to fp16/bf16.
-        reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.)
-            without changing the math. This is to estimate the numerical error from operation
-            reordering.
-    Output:
-        output: (batch_size, seqlen_q, nheads, head_dim)
-        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
-    """
-    dtype_og = q.dtype
-    if upcast:
-        q, k, v = q.float(), k.float(), v.float()
-    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
-    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
-    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
-    d = q.shape[-1]
-    if not reorder_ops:
-        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
-    else:
-        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
-    if key_padding_mask is not None:
-        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
-    if causal:
-        local_mask = construct_local_mask(
-            seqlen_q,
-            seqlen_k,
-            (-1, 0),
-            None,
-            None,
-            q.device,
-        )
-        scores.masked_fill_(local_mask, float("-inf"))
-    if attn_bias is not None:
-        scores = scores + attn_bias
-    attention = torch.softmax(scores, dim=-1).to(v.dtype)
-    # We want to mask here so that the attention matrix doesn't have any NaNs
-    # Otherwise we'll get NaN in dV
-    if query_padding_mask is not None:
-        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
-    # Some rows might be completely masked out so we fill them with zero instead of NaN
-    if causal:
-        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
-    dropout_scaling = 1.0 / (1 - dropout_p)
-    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
-    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
-    if dropout_mask is not None:
-        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
-    else:
-        attention_drop = attention
-    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
-    if query_padding_mask is not None:
-        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
-    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
-
-
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 # @pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@@ -142,10 +34,11 @@ def attention_ref(
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
 @pytest.mark.parametrize("d", [64, 128, 256])
-# @pytest.mark.parametrize("d", [256])
+# @pytest.mark.parametrize("d", [128])
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [
+        (257, 1),
         (64, 128),
         (128, 128),
         (256, 256),
@@ -175,8 +68,9 @@ def test_flash_attn_output(
     batch_size = 9
     nheads = 6
     nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
-    # batch_size = 1
-    # nheads = 1
+    # nheads_kv = 2
+    # batch_size = 9
+    # nheads = 6
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
     k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
     v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
@@ -244,9 +138,172 @@ def test_flash_attn_output(
     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
+    # breakpoint()
     assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
 
     # if d <= 128:
     #     assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
     #     assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
     #     assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
+
+
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+# @pytest.mark.parametrize('causal', [True])
+# @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize('d', [128])
+@pytest.mark.parametrize("d", [64, 128, 256])
+@pytest.mark.parametrize(
+    "seqlen_q,seqlen_k",
+    [
+        (1, 1),
+        (1, 3),
+        (2, 1),
+        (511, 1),
+        (3, 513),
+        (64, 128),
+        (113, 203),
+        (128, 128),
+        (128, 217),
+        (113, 211),
+        (108, 256),
+        (256, 512),
+        (384, 256),
+        (512, 256),
+        (640, 128),
+        (1024, 1024),
+        (1023, 1024),
+        (1024, 1023),
+        (2048, 2048),
+    ],
+)
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
+def test_flash_attn_varlen_output(
+    seqlen_q, seqlen_k, d, causal, mha_type, dtype
+):
+    if (
+        max(seqlen_q, seqlen_k) >= 2048
+        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
+    ):
+        pytest.skip()  # Reference implementation OOM
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    # batch_size = 1
+    # nheads = 1
+    batch_size = 9
+    nheads = 6
+    nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
+
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    k = torch.randn(
+        batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True
+    )
+    v = torch.randn(
+        batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True
+    )
+
+    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
+    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
+    # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
+
+    (
+        q_unpad,
+        k_unpad,
+        v_unpad,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        q,
+        k,
+        v,
+        output_pad_fn,
+        dq_pad_fn,
+        dk_pad_fn,
+    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
+    # print("cu_seqlens_q: ", cu_seqlens_q)
+    # print("cu_seqlens_k: ", cu_seqlens_k)
+    # print("q_unpad, shape: ", q_unpad.shape)
+    # print("k_unpad, shape: ", k_unpad.shape)
+    # print("v_unpad, shape: ", v_unpad.shape)
+    out_unpad, sm_lse = flash_attn_varlen_func(
+        q_unpad,
+        k_unpad,
+        v_unpad,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        causal=causal,
+    )
+    out = output_pad_fn(out_unpad)
+    dropout_mask = None
+
+    out_ref, attn_ref = attention_ref(
+        q,
+        k,
+        v,
+        query_padding_mask,
+        key_padding_mask,
+        causal=causal,
+    )
+    out_pt, attn_pt = attention_ref(
+        q,
+        k,
+        v,
+        query_padding_mask,
+        key_padding_mask,
+        causal=causal,
+        upcast=False,
+        reorder_ops=True,
+    )
+
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
+    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
+
+    # g = torch.randn_like(out)
+    # if d <= 128:
+    #     (
+    #         dq_unpad,
+    #         dk_unpad,
+    #         dv_unpad,
+    #     ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
+    #     dk = dk_pad_fn(dk_unpad)
+    #     dv = dk_pad_fn(dv_unpad)
+    #     (
+    #         dq_ref,
+    #         dk_ref,
+    #         dv_ref,
+    #     ) = torch.autograd.grad(out_ref, (q, k, v), g)
+    #     (
+    #         dq_pt,
+    #         dk_pt,
+    #         dv_pt,
+    #     ) = torch.autograd.grad(out_pt, (q, k, v), g)
+    #     dq = dq_pad_fn(dq_unpad)
+    #     print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
+    #     print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
+    #     print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
+    #     print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
+    #     print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
+    #     print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
+    #     print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
+    #     print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
+    #     print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
+    #     print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
+    #     print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
+    #     print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
+
+    # Check that FlashAttention's numerical error is at most twice the numerical error
+    # of a Pytorch implementation.
+    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
+
+    # if d <= 128:
+    #     assert (dq - dq_ref).abs().max().item() < 1e-4 or (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
+    #     assert (dk - dk_ref).abs().max().item() < 1e-4 or (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
+    #     assert (dk - dk_ref).abs().max().item() < 1e-4 or (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
diff --git a/hopper/utils.h b/hopper/utils.h
index 85392d5e8..90116f8a7 100644
--- a/hopper/utils.h
+++ b/hopper/utils.h
@@ -15,6 +15,7 @@
 #endif
 
 #include
+#include
 #include
 #include
@@ -228,4 +229,93 @@ __forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor
+template <typename ElemO, typename TMACopyO, typename LayoutO, typename TileShapeO,
+          typename SMemO, typename SeqLenTraits>
+__forceinline__ __device__ void write_tma(
+    ElemO* O, const TMACopyO& tma_store_O,
+    const LayoutO& layout_O, const TileShapeO& tile_shape_O,
+    const SMemO& sO, int m_block, int bidh, int bidb,
+    const SeqLenTraits& seqlen_traits_o, int write_warp_idx) {
+  Tensor mO = tma_store_O.get_tma_tensor(layout_O.shape());
+  Tensor gO = seqlen_traits_o.get_local_tile_tensor(
+      mO, tile_shape_O, bidh, bidb
+  )(_, _, m_block);  // (M, K)
+  auto block_tma_O = tma_store_O.get_slice(_0{});
+  Tensor tOgO = block_tma_O.partition_D(gO);  // (TMA, TMA_M, TMA_K)
+  Tensor tOsO = block_tma_O.partition_S(sO);  // (TMA, TMA_M, TMA_K)
+
+  int const lane_predicate = cute::elect_one_sync();
+  int const warp_idx = cutlass::canonical_warp_idx_sync();
+  if (warp_idx == write_warp_idx && lane_predicate) {
+    cute::copy(tma_store_O, tOsO, tOgO);
+    tma_store_arrive();
+  }
+  // Note: no wait here.
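+  // The matching tma_store_wait is expected to be issued by the caller before the
+  // smem tile backing sO is reused.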
+  // tma_store_wait<0>();
+}
+
+template <int NumCopyThreads, typename ElemO, typename TiledCopyO, typename LayoutO,
+          typename TileShapeO, typename SMemO, typename SeqLenTraits>
+__forceinline__ __device__ void write_tiled(
+    ElemO* O, const TiledCopyO& tiled_copy_O,
+    const LayoutO& layout_O, const TileShapeO& tile_shape_O,
+    const SMemO& sO, int m_block, int bidh, int bidb,
+    const SeqLenTraits& seqlen_traits_o) {
+  Tensor mO = make_tensor(make_gmem_ptr(O), layout_O);
+  Tensor gO = seqlen_traits_o.get_local_tile_tensor(
+      mO, tile_shape_O, bidh, bidb
+  )(_, _, m_block);  // (M, K)
+
+  ThrCopy thr_copy_O = tiled_copy_O.get_slice(threadIdx.x - NumCopyThreads);
+  Tensor tOgO = thr_copy_O.partition_D(gO);  // (CPY,CPY_M,CPY_K,k)
+  Tensor tOsO = thr_copy_O.partition_S(sO);  // (CPY,CPY_M,CPY_K)
+
+  // Prepare for TiledCopy.
+  // Grouping is needed because cute::copy_if() does group_modes<1, R> for src and dst.
+  // After grouping, the first dim is number of elements to read together.
+  Tensor tOsOFlatten = cute::flatten(tOsO);
+  Tensor tOsOGroup = cute::group_modes<1, rank(tOsOFlatten)>(tOsOFlatten);
+  Tensor tOgOFlatten = cute::flatten(tOgO);
+  Tensor tOgOGroup = cute::group_modes<1, rank(tOgOFlatten)>(tOgOFlatten);
+
+  // Get thread coords to global index mapping.
+  Tensor gOCounting = cute::make_identity_tensor(gO.shape());
+  Tensor tSgOCounting = thr_copy_O.partition_D(gOCounting);
+  Tensor tSgOCountingFlatten = cute::flatten(tSgOCounting);
+  Tensor tSgOCountingGrouped =
+      cute::group_modes<1, rank(tSgOCountingFlatten)>(tSgOCountingFlatten);
+
+  // Write out to GMEM.
+  const int kNumMsPerTile = get<0>(tile_shape_O);
+  int cta_m = std::min(
+      seqlen_traits_o.actual_seq_len - m_block * kNumMsPerTile, kNumMsPerTile
+  );
+  if (cta_m == kNumMsPerTile) {
+    copy(tiled_copy_O, tOsOGroup, tOgOGroup);
+  } else {
+    auto predicate_fn = [&](auto coords) {
+      auto s_coords = tSgOCountingGrouped(_0{}, coords);
+      return elem_less(get<0>(s_coords), cta_m);
+    };
+    copy_if(tiled_copy_O, predicate_fn, tOsOGroup, tOgOGroup);
+  }
+}
+
+template <bool IsTMACopy, int NumCopyThreads, typename ElemO, typename TMACopyO,
+          typename TiledCopyO, typename LayoutO, typename TileShapeO, typename SMemO,
+          typename SeqLenTraits>
+__forceinline__ __device__ void write_O(
+    ElemO* O, const TMACopyO& tma_copy_O, const TiledCopyO& tiled_copy_O,
+    const LayoutO& layout_O, const TileShapeO& tile_shape_O,
+    const SMemO& sO, int m_block, int bidh, int bidb,
+    const SeqLenTraits& seqlen_traits_o, int write_warp_idx) {
+  if constexpr (IsTMACopy) {
+    write_tma(O, tma_copy_O, layout_O, tile_shape_O, sO, m_block, bidh, bidb, seqlen_traits_o, write_warp_idx);
+  } else {
+    write_tiled<NumCopyThreads>(O, tiled_copy_O, layout_O, tile_shape_O, sO, m_block, bidh, bidb, seqlen_traits_o);
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 }  // namespace flash
diff --git a/tests/test_util.py b/tests/test_util.py
new file mode 100644
index 000000000..513a9b8e8
--- /dev/null
+++ b/tests/test_util.py
@@ -0,0 +1,254 @@
+import math
+
+import torch
+from einops import rearrange, repeat
+from flash_attn.bert_padding import pad_input, unpad_input
+
+
+def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
+    assert mode in ["full", "random", "third"]
+    if mode == "full":
+        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
+    elif mode == "random":
+        lengths = torch.randint(
+            max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
+        )
+    elif mode == "third":
+        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
+    padding_mask = (
+        repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
+    )
+    return padding_mask
+
+
+def generate_qkv(
+    q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
+):
+    """
+    Arguments:
+        q: (batch_size, seqlen_q, nheads, d)
+        k: (batch_size, seqlen_k, nheads_k, d)
+        v: (batch_size, seqlen_k, nheads_k, d)
+        query_padding_mask: (batch_size, seqlen), bool
+        key_padding_mask: (batch_size, seqlen), bool
+    """
+    assert not (kvpacked and qkvpacked)
+    batch_size, seqlen_q, nheads, d = q.shape
+    _, seqlen_k, nheads_k, _ = k.shape
+    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
+    assert v.shape == (batch_size, seqlen_k, nheads_k, d)
+
+    if query_padding_mask is not None:
+        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
+        output_pad_fn = lambda output_unpad: pad_input(
+            output_unpad, indices_q, batch_size, seqlen_q
+        )
+    else:
+        q_unpad = rearrange(q, "b s h d -> (b s) h d")
+        cu_seqlens_q = torch.arange(
+            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
+        )
+        max_seqlen_q = seqlen_q
+        output_pad_fn = lambda output_unpad: rearrange(
+            output_unpad, "(b s) h d -> b s h d", b=batch_size
+        )
+
+    if key_padding_mask is not None:
+        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
+        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
+    else:
+        k_unpad = rearrange(k, "b s h d -> (b s) h d")
+        v_unpad = rearrange(v, "b s h d -> (b s) h d")
+        cu_seqlens_k = torch.arange(
+            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
+        )
+        max_seqlen_k = seqlen_k
+
+    if qkvpacked:
+        assert (query_padding_mask == key_padding_mask).all()
+        assert nheads == nheads_k
+        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
+        qkv = torch.stack([q, k, v], dim=2)
+        if query_padding_mask is not None:
+            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
+        else:
+            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
+                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
+            )
+        return (
+            qkv_unpad.detach().requires_grad_(),
+            cu_seqlens_q,
+            max_seqlen_q,
+            qkv.detach().requires_grad_(),
+            output_pad_fn,
+            dqkv_pad_fn,
+        )
+    elif kvpacked:
+        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
+        kv = torch.stack([k, v], dim=2)
+        dq_pad_fn = output_pad_fn
+        if key_padding_mask is not None:
+            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
+        else:
+            dkv_pad_fn = lambda dkv_unpad: rearrange(
+                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
+            )
+        return (
+            q_unpad.detach().requires_grad_(),
+            kv_unpad.detach().requires_grad_(),
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            q.detach().requires_grad_(),
+            kv.detach().requires_grad_(),
+            output_pad_fn,
+            dq_pad_fn,
+            dkv_pad_fn,
+        )
+    else:
+        dq_pad_fn = output_pad_fn
+        if key_padding_mask is not None:
+            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
+        else:
+            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
+        return (
+            q_unpad.detach().requires_grad_(),
+            k_unpad.detach().requires_grad_(),
+            v_unpad.detach().requires_grad_(),
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            q.detach().requires_grad_(),
+            k.detach().requires_grad_(),
+            v.detach().requires_grad_(),
+            output_pad_fn,
+            dq_pad_fn,
+            dk_pad_fn,
+        )
+
+
+def construct_local_mask(
+    seqlen_q,
+    seqlen_k,
+    window_size=(-1, -1),  # -1 means infinite window size
+    query_padding_mask=None,
+    key_padding_mask=None,
+    device=None,
+    key_leftpad=None,
+):
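+    # Returns a bool mask that is True where attention is disallowed: positions outside
+    # the (left, right) window measured from the diagonal implied by the per-batch
+    # key/query lengths. key_leftpad (if given) shifts key indices so left-padded keys
+    # are always masked.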
+    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
+    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+    if key_leftpad is not None:
+        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
+        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
+        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
+    sk = (
+        seqlen_k
+        if key_padding_mask is None
+        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+    )
+    sq = (
+        seqlen_q
+        if query_padding_mask is None
+        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
+    )
+    if window_size[0] < 0:
+        return col_idx > row_idx + sk - sq + window_size[1]
+    else:
+        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
+        return torch.logical_or(
+            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
+            col_idx < row_idx + sk - sq - window_size[0],
+        )
+
+
+def attention_ref(
+    q,
+    k,
+    v,
+    query_padding_mask=None,
+    key_padding_mask=None,
+    attn_bias=None,
+    dropout_p=0.0,
+    dropout_mask=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite window size
+    softcap=0.0,
+    upcast=True,
+    reorder_ops=False,
+    key_leftpad=None,
+):
+    """
+    Arguments:
+        q: (batch_size, seqlen_q, nheads, head_dim)
+        k: (batch_size, seqlen_k, nheads_k, head_dim)
+        v: (batch_size, seqlen_k, nheads_k, head_dim)
+        query_padding_mask: (batch_size, seqlen_q)
+        key_padding_mask: (batch_size, seqlen_k)
+        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
+        dropout_p: float
+        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
+        causal: whether to apply causal masking
+        window_size: (int, int), left and right window size
+        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
+            output back to fp16/bf16.
+        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
+            without changing the math. This is to estimate the numerical error from operation
+            reordering.
+    Output:
+        output: (batch_size, seqlen_q, nheads, head_dim)
+        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
+    """
+    if causal:
+        window_size = (window_size[0], 0)
+    dtype_og = q.dtype
+    if upcast:
+        q, k, v = q.float(), k.float(), v.float()
+    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
+    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
+    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
+    d = q.shape[-1]
+    if not reorder_ops:
+        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
+    else:
+        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
+    if softcap > 0:
+        scores /= softcap
+        scores = scores.tanh()
+        scores *= softcap
+    if key_padding_mask is not None:
+        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
+    if window_size[0] >= 0 or window_size[1] >= 0:
+        local_mask = construct_local_mask(
+            seqlen_q,
+            seqlen_k,
+            window_size,
+            query_padding_mask,
+            key_padding_mask,
+            q.device,
+            key_leftpad=key_leftpad,
+        )
+        scores.masked_fill_(local_mask, float("-inf"))
+    if attn_bias is not None:
+        scores = scores + attn_bias
+    attention = torch.softmax(scores, dim=-1).to(v.dtype)
+    # Some rows might be completely masked out so we fill them with zero instead of NaN
+    if window_size[0] >= 0 or window_size[1] >= 0:
+        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
+    # We want to mask here so that the attention matrix doesn't have any NaNs
+    # Otherwise we'll get NaN in dV
+    if query_padding_mask is not None:
+        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
+    dropout_scaling = 1.0 / (1 - dropout_p)
+    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
+    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
+    if dropout_mask is not None:
+        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
+    else:
+        attention_drop = attention
+    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
+    if query_padding_mask is not None:
+        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
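For reference, a minimal usage sketch of the flash_attn_varlen_func entry point exercised above, assuming the packed (total_tokens, nheads, headdim) layout and the int32 prefix-sum cu_seqlens that generate_qkv produces; the batch size, lengths, and head configuration below are illustrative only.

import torch
from flash_attn_interface import flash_attn_varlen_func

device, dtype = "cuda", torch.float16
nheads, headdim = 6, 128
# Per-sequence lengths for a batch of 3; cu_seqlens are 0-prefixed cumulative sums.
seqlens_q = torch.tensor([5, 17, 128], device=device, dtype=torch.int32)
seqlens_k = torch.tensor([7, 20, 128], device=device, dtype=torch.int32)
cu_seqlens_q = torch.nn.functional.pad(seqlens_q.cumsum(0, dtype=torch.int32), (1, 0))
cu_seqlens_k = torch.nn.functional.pad(seqlens_k.cumsum(0, dtype=torch.int32), (1, 0))
# q/k/v are packed along the token dimension, with no padding tokens stored.
q = torch.randn(int(seqlens_q.sum()), nheads, headdim, device=device, dtype=dtype)
k = torch.randn(int(seqlens_k.sum()), nheads, headdim, device=device, dtype=dtype)
v = torch.randn_like(k)
out, lse = flash_attn_varlen_func(
    q, k, v, cu_seqlens_q, cu_seqlens_k,
    int(seqlens_q.max()), int(seqlens_k.max()),
    causal=True,
)
# out has shape (total_q_tokens, nheads, headdim), matching q.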