diff --git a/.gitmodules b/.gitmodules
index 8d501cb19..6216182e7 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,6 @@
 [submodule "csrc/cutlass"]
 	path = csrc/cutlass
 	url = https://github.com/NVIDIA/cutlass.git
+[submodule "csrc/composable_kernel"]
+	path = csrc/composable_kernel
+	url = https://github.com/ROCm/composable_kernel.git
diff --git a/README.md b/README.md
index 6a836c9f3..b6efd8ee3 100644
--- a/README.md
+++ b/README.md
@@ -434,6 +434,33 @@ This new release of FlashAttention-2 has been tested on several GPT-style
 models, mostly on A100 GPUs.
 
 If you encounter bugs, please open a GitHub Issue!
 
+## AMD GPU/ROCm Support
+The ROCm version uses [composable_kernel](https://github.com/ROCm/composable_kernel) as its backend, which provides the implementation of FlashAttention-2.
+
+## Installation and features
+Requirements:
+- ROCm 6.0+
+- PyTorch 1.12.1+
+
+We recommend the
+[Pytorch](https://hub.docker.com/r/rocm/pytorch)
+container from ROCm, which has all the required tools to install FlashAttention.
+
+To compile from source:
+```sh
+python setup.py install
+```
+
+FlashAttention-2 on ROCm currently supports:
+1. MI200 or MI300 GPUs.
+2. Datatypes fp16 and bf16.
+3. Head dimensions up to 256 in the forward pass and up to 128 in the backward pass.
+
+## Tests
+To run the tests:
+```sh
+pytest tests/test_flash_attn_ck.py
+```
 
 ## Citation
 If you use this codebase, or otherwise found our work valuable, please cite:
diff --git a/csrc/composable_kernel b/csrc/composable_kernel
new file mode 160000
index 000000000..8182976c3
--- /dev/null
+++ b/csrc/composable_kernel
@@ -0,0 +1 @@
+Subproject commit 8182976c37433808b5e3a27a6536d1b74b0c23a1
diff --git a/csrc/flash_attn_ck/flash_api.cpp b/csrc/flash_attn_ck/flash_api.cpp
new file mode 100644
index 000000000..0c7474b97
--- /dev/null
+++ b/csrc/flash_attn_ck/flash_api.cpp
@@ -0,0 +1,99 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#include "flash_common.hpp"
+
+std::vector<at::Tensor>
+mha_fwd(at::Tensor &q,
+        const at::Tensor &k,
+        const at::Tensor &v,
+        c10::optional<at::Tensor> &out_,
+        c10::optional<at::Tensor> &alibi_slopes_,
+        const float p_dropout,
+        const float softmax_scale,
+        bool is_causal,
+        int window_size_left,
+        int window_size_right,
+        const float softcap,
+        const bool return_softmax,
+        c10::optional<at::Generator> gen_);
+
+std::vector<at::Tensor>
+mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+               const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+               const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+               c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+               const at::Tensor &cu_seqlens_q, // b+1
+               const at::Tensor &cu_seqlens_k, // b+1
+               c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
+ c10::optional &leftpad_k_, // batch_size + c10::optional &block_table_, // batch_size x max_num_blocks_per_seq + c10::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool return_softmax, + c10::optional gen_); + +std::vector +mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + c10::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + c10::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool deterministic, + c10::optional gen_, + c10::optional &rng_state); + +std::vector +mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads x head_size + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &out, // total_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x s softmax logsumexp + c10::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + c10::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + c10::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + c10::optional &alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + const bool deterministic, + c10::optional gen_, + c10::optional &rng_state); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.doc() = "FlashAttention"; + m.def("fwd", &mha_fwd, "Forward pass"); + m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)"); + m.def("bwd", &mha_bwd, "Backward pass"); + m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)"); +} diff --git a/csrc/flash_attn_ck/flash_common.hpp b/csrc/flash_attn_ck/flash_common.hpp new file mode 100644 index 000000000..cc601f970 --- /dev/null +++ b/csrc/flash_attn_ck/flash_common.hpp @@ -0,0 +1,38 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers. 
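+// In addition, this header collects the small helpers shared by all of the
+// csrc/flash_attn_ck bindings: the CHECK_DEVICE / CHECK_SHAPE / CHECK_CONTIGUOUS
+// argument-validation macros and flash::unpack(), which extracts the Philox
+// seed/offset pair used to seed dropout.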
+#include +#include +#include +#include + +#ifdef OLD_GENERATOR_PATH +#include +#else +#include +#endif + + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +namespace flash { +// Copy from PyTorch +// https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17 +static std::tuple unpack(at::PhiloxCudaState arg) { + if (arg.captured_) { + // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long". + // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel. + // For most threads' reads it will hit in cache, so it shouldn't hurt performance. + return std::make_tuple(static_cast(*arg.seed_.ptr), static_cast(*(arg.offset_.ptr) + arg.offset_intragraph_)); + } else { + return std::make_tuple(arg.seed_.val, arg.offset_.val); + } +} + +} // namespace flash diff --git a/csrc/flash_attn_ck/mha_bwd.cpp b/csrc/flash_attn_ck/mha_bwd.cpp new file mode 100644 index 000000000..884215adf --- /dev/null +++ b/csrc/flash_attn_ck/mha_bwd.cpp @@ -0,0 +1,379 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include "flash_common.hpp" + +#include "fmha_bwd.hpp" +#include "mask.hpp" + +fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask, + std::string dtype, + int head_size, + bool has_dropout, + bool enable_alibi) +{ + return fmha_bwd_traits{head_size, + head_size, + dtype, + false, // is_group_mode + mask.type, + enable_alibi ? 
bias_enum::alibi : bias_enum::no_bias, + false, // has_dbias + has_dropout}; +} + +fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, + // sizes + const int b, + const int seqlen_q, + const int seqlen_k, + const int h, + const int h_k, + const int hdim, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + c10::optional &alibi_slopes_, + const at::Tensor out, + const at::Tensor softmax_lse, + const at::Tensor dout, + at::Tensor d, + at::Tensor dq, + at::Tensor dk, + at::Tensor dv, + float softmax_scale, + float p_dropout, + uint64_t drop_seed, + uint64_t drop_offset) +{ + // q: (batch_size, seqlen_q, nheads, hdim) + // k: (batch_size, seqlen_k, nheads_k, hdim) + // v: (batch_size, seqlen_k, nheads_k, hdim) + // o: (batch_size, seqlen_q, nheads, hdim) + // dq: (batch_size, seqlen_q, nheads, hdim) + // dk_expanded: (batch_size, seqlen_k, nheads, hdim) + // dv_expanded: (batch_size, seqlen_k, nheads, hdim) + // do: (batch_size, seqlen_q, nheads, hdim) + + // alibi_slopes:(batch_size, nheads) or (nhead) + // lse: (batch_size, nheads, seqlen_q) + // d: (batch_size, nheads, seqlen_q) + + ck_tile::index_t stride_q = q.stride(1); + ck_tile::index_t stride_k = k.stride(1); + ck_tile::index_t stride_v = v.stride(1); + ck_tile::index_t stride_o = out.stride(1); + ck_tile::index_t stride_do = dout.stride(1); + ck_tile::index_t stride_dk = dk.stride(1); + ck_tile::index_t stride_dv = dv.stride(1); + + ck_tile::index_t nhead_stride_q = q.stride(2); + ck_tile::index_t nhead_stride_k = k.stride(2); + ck_tile::index_t nhead_stride_v = v.stride(2); + ck_tile::index_t nhead_stride_o = out.stride(2); + ck_tile::index_t nhead_stride_do = dout.stride(2); + ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1); + + ck_tile::index_t batch_stride_q = q.stride(0); + ck_tile::index_t batch_stride_k = k.stride(0); + ck_tile::index_t batch_stride_v = v.stride(0); + ck_tile::index_t batch_stride_o = out.stride(0); + ck_tile::index_t batch_stride_do = dout.stride(0); + ck_tile::index_t batch_stride_lse = softmax_lse.stride(0); + ck_tile::index_t batch_stride_dk = dk.stride(0); + ck_tile::index_t batch_stride_dv = dv.stride(0); + + float p_undrop = 1.0 - p_dropout; + + void *alibi_slopes_ptr = nullptr; + ck_tile::index_t stride_alibi_slopes = 0; + + if (alibi_slopes_.has_value()) { + auto alibi_slopes = alibi_slopes_.value(); + CHECK_DEVICE(alibi_slopes); + TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); + TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h})); + alibi_slopes_ptr = alibi_slopes.data_ptr(); + stride_alibi_slopes = alibi_slopes.dim() == 2 ? 
alibi_slopes.stride(0) : 0; + } + + return fmha_bwd_args{q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + alibi_slopes_ptr, // bias + out.data_ptr(), + softmax_lse.data_ptr(), + dout.data_ptr(), + d.data_ptr(), + nullptr, // rand_val + dq.data_ptr(), + dk.data_ptr(), + dv.data_ptr(), + nullptr, // dbias + nullptr, // seqstart_q + nullptr, // seqstart_k + nullptr, // seqlen_k_ptr + seqlen_q, + seqlen_k, + b, + seqlen_q, // max_seqlen_q + seqlen_k, // max_seqlen_k + hdim, // hdim_q + hdim, // hdim_v + h, // nhead + h_k, // nhead_k + softmax_scale, + stride_q, + stride_k, + stride_v, + stride_alibi_slopes, + stride_o, + 0, // stride_randval + stride_do, + stride_dk, + stride_dv, + 0, // stride_dbias, FA without bias + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + 0, // nhead_stride_bias, FA without bias + nhead_stride_o, + 0, // nhead_stride_randval + nhead_stride_do, + nhead_stride_lse, + 0, // nhead_stride_dbias, FA without dbias + batch_stride_q, + batch_stride_k, + batch_stride_v, + 0 , // batch_stride_bias, FA without bias + batch_stride_o, + 0, // batch_stride_randval + batch_stride_do, + batch_stride_lse, + batch_stride_dk, + batch_stride_dv, + 0 , // batch_stride_dbias, FA without dbias + mask.left, + mask.right, + static_cast(mask.type), + p_dropout, + p_undrop, + false, // s_randval + {drop_seed, drop_offset}}; +} + +std::vector +mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + c10::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + c10::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, // probability to drop + const float softmax_scale, + const bool is_causal, + int window_size_left, + int window_size_right, + const float /*softcap*/, + const bool deterministic, + c10::optional gen_, + c10::optional &rng_state) +{ +#ifdef FLASHATTENTION_DISABLE_BACKWARD + TORCH_CHECK(false, "This flash attention build does not support backward."); +#endif + if (is_causal) { window_size_right = 0; } + + bool is_dropout = p_dropout > 0.0; + auto stream = at::cuda::getCurrentHIPStream().stream(); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); + + std::string q_dtype_str = q_dtype == torch::kFloat16 ? 
"fp16" : "bf16"; + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + const int seqlen_q = sizes[1]; + const int num_heads = sizes[2]; + const int head_size_og = dout.size(3); // unpadded hdim + const int head_size_8x = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_8x % 8 == 0, "head_size_8x should be a multiple of 8"); + TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + TORCH_CHECK(head_size_8x == round_multiple(head_size_og, 8), "head_size_8x must be head_size_og rounded to a multiple of 8"); + + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + + mask_info mask; + if (is_causal) { + std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0"; + mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // casual + } + else if (window_size_left == -1 && window_size_right == -1) { + mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask + } + else { + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
+ std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right); + mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local + } + + // q, k, v, out had been padded in mha_fwd + // dq_, dk_, dv_ are also padded tensor + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_8x); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_8x); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_8x); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_8x); + CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og); + + at::Tensor dq, dk, dv; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size_8x); + } else { + dq = torch::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size_8x); + } else { + dk = torch::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_8x); + } else { + dv = torch::empty_like(v); + } + + at::Tensor dout_padded; + if (head_size_og % 8 != 0) { + dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + dout_padded = dout; + } + + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + // TODO - CK does not support dq_accum + + at::Tensor dk_expanded, dv_expanded; + if (num_heads_k != num_heads) { // MQA / GQA + dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size_8x}, opts); + dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size_8x}, opts); + } else { + dk_expanded = dk; + dv_expanded = dv; + } + + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + + uint64_t drop_seed = 1, drop_offset = 0; + int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size(); + + if (rng_state.has_value()) { + uint64_t* d = reinterpret_cast(rng_state.value().data_ptr()); + drop_seed = d[0]; + drop_offset = d[1]; + } else if(is_dropout) { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + auto philox_args = gen->philox_cuda_state(counter_offset); + std::tie(drop_seed, drop_offset) = flash::unpack(philox_args); + } + + if (seqlen_q > 0) { + ck_tile::stream_config stream_config{stream}; + dq.zero_(); // ck use atomic operation on dq + + auto traits = + get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value()); + + auto args = + get_ck_fmha_bwd_args( + mask, + batch_size, + seqlen_q, + seqlen_k, + num_heads, + num_heads_k, + head_size_8x, + q, + k, + v, + alibi_slopes_, + out, + softmax_lse, + dout_padded, + softmax_d, + dq, + dk_expanded, + dv_expanded, + softmax_scale, + p_dropout, + 
drop_seed, + drop_offset); + + fmha_bwd(traits, args, stream_config); + } else { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. + dk_expanded.zero_(); + dv_expanded.zero_(); + softmax_d.zero_(); + } + + // For MQA/GQA we need to sum dK and dV across the groups + if (num_heads_k != num_heads) { + at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {3}); + at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {3}); + } + if (head_size_og % 8 != 0) { + dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + } + + return { dq, dk, dv, softmax_d }; +} \ No newline at end of file diff --git a/csrc/flash_attn_ck/mha_fwd.cpp b/csrc/flash_attn_ck/mha_fwd.cpp new file mode 100644 index 000000000..c1eeba507 --- /dev/null +++ b/csrc/flash_attn_ck/mha_fwd.cpp @@ -0,0 +1,348 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include "flash_common.hpp" + +#include "fmha_fwd.hpp" +#include "mask.hpp" + +fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask, + std::string dtype, + int head_size, + bool has_dropout, + bool has_lse, + bool enable_alibi) +{ + return fmha_fwd_traits{head_size, + head_size, + dtype, + false, // is_group_mode + true, // is_v_rowmajor + mask.type, + enable_alibi ? bias_enum::alibi : bias_enum::no_bias, + has_lse, + has_dropout, + false}; // do_fp8_static_quant +} + +fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, + bool has_dropout_randval, + const mask_info &mask, + // sizes + const int b, + const int seqlen_q, + const int seqlen_k, + const int h, + const int h_k, + const int d, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + c10::optional &alibi_slopes_, + at::Tensor out, + at::Tensor softmax_lse, + at::Tensor dropout_randval, + float softmax_scale, + float p_dropout, + uint64_t drop_seed, + uint64_t drop_offset) +{ + // q: (batch_size, seqlen_q, nheads, d) + // k: (batch_size, seqlen_k, nheads_k, d) + // v: (batch_size, seqlen_k, nheads_k, d) + // o: (batch_size, seqlen_q, nheads, d) + + // alibi_slopes:(batch_size, nheads) or (nhead) + // lse: (batch_size, nheads, seqlen_q) + // randval: (batch_size, nheads, seqlen_q, seqlen_k) + + ck_tile::index_t stride_q = q.stride(1); + ck_tile::index_t stride_k = k.stride(1); + ck_tile::index_t stride_v = v.stride(1); + ck_tile::index_t stride_o = out.stride(1); + ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(2) : 0; + + ck_tile::index_t nhead_stride_q = q.stride(2); + ck_tile::index_t nhead_stride_k = k.stride(2); + ck_tile::index_t nhead_stride_v = v.stride(2); + ck_tile::index_t nhead_stride_o = out.stride(2); + ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(1) : 0; + ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0; + + ck_tile::index_t batch_stride_q = q.stride(0); + ck_tile::index_t batch_stride_k = k.stride(0); + ck_tile::index_t batch_stride_v = v.stride(0); + ck_tile::index_t batch_stride_o = out.stride(0); + + ck_tile::index_t batch_stride_lse = has_lse ? 
softmax_lse.stride(0) : 0; + ck_tile::index_t batch_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0; + + void *alibi_slopes_ptr = nullptr; + ck_tile::index_t stride_alibi_slopes = 0; + + if (alibi_slopes_.has_value()) { + auto alibi_slopes = alibi_slopes_.value(); + CHECK_DEVICE(alibi_slopes); + TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); + TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h})); + alibi_slopes_ptr = alibi_slopes.data_ptr(); + stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0; + } + + return fmha_fwd_args{q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + alibi_slopes_ptr, // bias + has_dropout_randval ? dropout_randval.data_ptr() : nullptr, + nullptr, // lse_acc + nullptr, // o_acc + has_lse ? softmax_lse.data_ptr() : nullptr, + out.data_ptr(), + nullptr, // seqstart_q + nullptr, // seqstart_k + nullptr, + seqlen_q, + seqlen_k, + b, + seqlen_q, // max_seqlen_q + d, // hdim_q + d, // hdim_v + h, // nhead + h_k, // nhead_k + 1, // num_splits + softmax_scale, // scale_s + 1, // scale_p + 1, // scale_o + stride_q, + stride_k, + stride_v, + stride_alibi_slopes, + stride_randval, + 0, // stride_o_acc, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + 0, // nhead_stride_bias, FA without bias + nhead_stride_randval, + nhead_stride_lse, + 0, // nhead_stride_lse_acc + 0, // nhead_stride_o_acc + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + 0, // batch_stride_bias, FA without bias + batch_stride_randval, + batch_stride_lse, + 0, // batch_stride_lse_acc + 0, // batch_stride_o_acc + batch_stride_o, + 0, // split_stride_lse_acc + 0, // split_stride_o_acc + mask.left, + mask.right, + static_cast(mask.type), + p_dropout, + has_dropout_randval, + {drop_seed, drop_offset}}; +} + +std::vector +mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size + c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const float /*softcap*/, + const bool return_dropout_randval, + c10::optional gen_) +{ + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + + std::string q_dtype_str = q_dtype == torch::kFloat16 ? 
"fp16" : "bf16"; + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + int seqlen_q = sizes[1]; + int num_heads = sizes[2]; + const int head_size_og = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_og <= 256, "CK only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + + // causal=true is the same as causal=false in this case + if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } + + mask_info mask; + if (is_causal) { + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + window_size_right = 0; + std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0"; + mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // casual + } + else if (window_size_left == -1 && window_size_right == -1) { + mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask + } + else { + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. + std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right); + mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local + } + + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case + // H/t Daniel Haziza + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value(); + const int ngroups = num_heads / num_heads_k; + if (seqlenq_ngroups_swapped) { + q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); + seqlen_q = ngroups; + num_heads = num_heads_k; + } + + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og); + + at::Tensor q_padded, k_padded, v_padded; + if (head_size_og % 8 != 0) { + q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } + else { + q_padded = q; + k_padded = k; + v_padded = v; + } + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og); + if (seqlenq_ngroups_swapped) { + out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); + } + if (head_size_og % 8 != 0) { out = 
torch::empty_like(q_padded); } + } + else { + out = torch::empty_like(q_padded); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_8x = round_multiple(head_size_og, 8); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + bool has_lse = true; + bool has_dropout = p_dropout > 0.0f; + + at::Tensor softmax_lse; + // TODO - check gradient, only training require lse + softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(torch::kFloat32)); + + at::Tensor p; + if (return_dropout_randval) { + TORCH_CHECK(has_dropout, "return_dropout_randval require p_dropout > 0"); + p = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts.dtype(torch::kUInt8)); + } + + uint64_t drop_seed = 1, drop_offset = 0; + int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size(); + auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); + + if (p_dropout > 0.0) { + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + auto philox_args = gen->philox_cuda_state(counter_offset); + std::tie(drop_seed, drop_offset) = flash::unpack(philox_args); + } + + rng_state[0] = *(reinterpret_cast(&drop_seed)); + rng_state[1] = *(reinterpret_cast(&drop_offset)); + + if (seqlen_k > 0) { + auto stream = at::cuda::getCurrentHIPStream().stream(); + ck_tile::stream_config stream_config{stream}; + + auto traits = + get_ck_fmha_fwd_traits(mask, q_dtype_str, head_size_8x, has_dropout, has_lse, alibi_slopes_.has_value()); + + auto args = + get_ck_fmha_fwd_args( + has_lse, + return_dropout_randval, + mask, + batch_size, + seqlen_q, + seqlen_k, + num_heads, + num_heads_k, + head_size_8x, + q_padded, + k_padded, + v_padded, + alibi_slopes_, + out, + softmax_lse, + p, + softmax_scale, + p_dropout, + drop_seed, + drop_offset); + + fmha_fwd(traits, args, stream_config); + } + else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + at::Tensor out_padded = out; + if (head_size_og % 8 != 0) { + out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + if (out_.has_value()) { out_.value().copy_(out); } + } + + if (seqlenq_ngroups_swapped) { + out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og}); + out_padded = out_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og}); + q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og}); + softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); + } + return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state}; +} diff --git a/csrc/flash_attn_ck/mha_varlen_bwd.cpp b/csrc/flash_attn_ck/mha_varlen_bwd.cpp new file mode 100644 index 000000000..d8eabab15 --- /dev/null +++ b/csrc/flash_attn_ck/mha_varlen_bwd.cpp @@ -0,0 +1,406 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ + +#include "flash_common.hpp" + +#include "fmha_bwd.hpp" +#include "mask.hpp" + +fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask, + std::string dtype, + int head_size, + bool has_dropout, + bool enable_alibi) +{ + return fmha_bwd_traits{head_size, + head_size, + dtype, + true, // is_group_mode + mask.type, + enable_alibi ? bias_enum::alibi : bias_enum::no_bias, + false, // has_dbias + has_dropout}; +} + +fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask, + // sizes + const int b, + const int max_seqlen_q, + const int max_seqlen_k, + const int h, + const int h_k, + const int hdim, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + const at::Tensor seqlens_q, + const at::Tensor seqlens_k, + c10::optional &alibi_slopes_, + const at::Tensor out, + const at::Tensor softmax_lse, + const at::Tensor dout, + at::Tensor d, + at::Tensor dq, + at::Tensor dk, + at::Tensor dv, + float softmax_scale, + float p_dropout, + uint64_t drop_seed, + uint64_t drop_offset) +{ + // q: (total_q, nheads, hdim) + // k: (total_k, nheads_k, hdim) + // v: (total_k, nheads_k, hdim) + // o: (total_q, nheads, hdim) + // dq: (total_q, nheads, hdim) + // dk_expanded: (total_k, nheads, hdim) + // dv_expanded: (total_k, nheads, hdim) + // do: (total_q, nheads, hdim) + + // alibi_slopes:(batch_size, nheads) or (nhead) + // lse: (batch_size, nheads, max_seqlen_q) + // d: (batch_size, nheads, max_seqlen_q) + + ck_tile::index_t total_q = q.size(0); + ck_tile::index_t total_k = k.size(0); + + ck_tile::index_t stride_q = q.stride(0); + ck_tile::index_t stride_k = k.stride(0); + ck_tile::index_t stride_v = v.stride(0); + ck_tile::index_t stride_o = out.stride(0); + ck_tile::index_t stride_do = dout.stride(0); + ck_tile::index_t stride_dk = dk.stride(0); + ck_tile::index_t stride_dv = dv.stride(0); + + ck_tile::index_t nhead_stride_q = q.stride(1); + ck_tile::index_t nhead_stride_k = k.stride(1); + ck_tile::index_t nhead_stride_v = v.stride(1); + ck_tile::index_t nhead_stride_o = out.stride(1); + ck_tile::index_t nhead_stride_do = dout.stride(1); + ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1); + + ck_tile::index_t batch_stride_q = 0; + ck_tile::index_t batch_stride_k = 0; + ck_tile::index_t batch_stride_v = 0; + ck_tile::index_t batch_stride_o = 0; + ck_tile::index_t batch_stride_do = 0; + ck_tile::index_t batch_stride_lse = softmax_lse.stride(0);; + ck_tile::index_t batch_stride_dk = 0; + ck_tile::index_t batch_stride_dv = 0; + + float p_undrop = 1.0 - p_dropout; + + void *alibi_slopes_ptr = nullptr; + ck_tile::index_t stride_alibi_slopes = 0; + + if (alibi_slopes_.has_value()) { + auto alibi_slopes = alibi_slopes_.value(); + CHECK_DEVICE(alibi_slopes); + TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); + TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h})); + alibi_slopes_ptr = alibi_slopes.data_ptr(); + stride_alibi_slopes = alibi_slopes.dim() == 2 ? 
alibi_slopes.stride(0) : 0; + } + + return fmha_bwd_args{q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + alibi_slopes_ptr, // bias + out.data_ptr(), + softmax_lse.data_ptr(), + dout.data_ptr(), + d.data_ptr(), + nullptr, // rand_val + dq.data_ptr(), + dk.data_ptr(), + dv.data_ptr(), + nullptr, // dbias + seqlens_q.data_ptr(), // seqstart_q + seqlens_k.data_ptr(), // seqstart_k + nullptr, // seqlen_k_ptr + total_q, + total_k, + b, + max_seqlen_q, // max_seqlen_q + max_seqlen_k, // max_seqlen_k + hdim, // hdim_q + hdim, // hdim_v + h, // nhead + h_k, // nhead_k + softmax_scale, + stride_q, + stride_k, + stride_v, + stride_alibi_slopes, + stride_o, + 0, // stride_randval + stride_do, + stride_dk, + stride_dv, + 0, // stride_dbias, FA without bias + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + 0, // nhead_stride_bias, FA without bias + nhead_stride_o, + 0, // nhead_stride_randval + nhead_stride_do, + nhead_stride_lse, + 0, // nhead_stride_dbias, FA without dbias + batch_stride_q, + batch_stride_k, + batch_stride_v, + 0 , // batch_stride_bias, FA without bias + batch_stride_o, + 0, // batch_stride_randval + batch_stride_do, + batch_stride_lse, + batch_stride_dk, + batch_stride_dv, + 0 , // batch_stride_dbias, FA without dbias + mask.left, + mask.right, + static_cast(mask.type), + p_dropout, + p_undrop, + false, // s_randval + {drop_seed, drop_offset}}; +} + +std::vector +mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads x head_size + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &out, // total_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x s softmax logsumexp + c10::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + c10::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + c10::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + c10::optional &alibi_slopes_, // num_heads or b x num_heads + const int max_seqlen_q, + const int max_seqlen_k, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + int window_size_left, + int window_size_right, + const float /*softcap*/, + const bool deterministic, + c10::optional gen_, + c10::optional &rng_state) +{ +#ifdef FLASHATTENTION_DISABLE_BACKWARD + TORCH_CHECK(false, "This flash attention build does not support backward."); +#endif + if (is_causal) { window_size_right = 0; } + + bool is_dropout = p_dropout > 0.0; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == 
torch::kInt32, "cu_seqlens_k must have dtype int32"); + + std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16"; + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); + + const auto sizes = q.sizes(); + + const int total_q = sizes[0]; + const int batch_size = cu_seqlens_q.numel() - 1; + const int num_heads = sizes[1]; + const int head_size_og = dout.size(2); + const int head_size_8x = sizes[2]; + const int total_k = k.size(0); + const int num_heads_k = k.size(1); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_8x % 8 == 0, "head_size should be a multiple of 8"); + TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + TORCH_CHECK(head_size_8x == round_multiple(head_size_og, 8), "head_size_8x must be head_size_og rounded to a multiple of 8"); + + if (window_size_left >= max_seqlen_k) { window_size_left = -1; } + if (window_size_right >= max_seqlen_k) { window_size_right = -1; } + + mask_info mask; + if (is_causal) { + std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0"; + mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // casual + } + else if (window_size_left == -1 && window_size_right == -1) { + mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask + } + else { + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
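+        // The string handed to mask_info::decode follows the ck_tile fmha convention:
+        // "0" requests no mask, while "b:<left>,<right>" appears to request a
+        // bottom-right-aligned sliding window of <left> keys to the left and <right>
+        // keys to the right of each query (-1 meaning unbounded). The causal branch
+        // above is simply the special case "b:<left>,0".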
+ std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right); + mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local + } + + // q, k, v, out had been padded in mha_fwd + // dq_, dk_, dv_ are also padded tensor + CHECK_SHAPE(q, total_q, num_heads, head_size_8x); + CHECK_SHAPE(k, total_k, num_heads_k, head_size_8x); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_8x); + CHECK_SHAPE(out, total_q, num_heads, head_size_8x); + CHECK_SHAPE(dout, total_q, num_heads, head_size_og); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + at::Tensor dq, dk, dv; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + CHECK_SHAPE(dq, total_q, num_heads, head_size_8x); + } else { + dq = torch::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + CHECK_SHAPE(dk, total_k, num_heads_k, head_size_8x); + } else { + dk = torch::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + CHECK_SHAPE(dv, total_k, num_heads_k, head_size_8x); + } else { + dv = torch::empty_like(v); + } + + at::Tensor dout_padded; + if (head_size_og % 8 != 0) { + dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + dout_padded = dout; + } + + + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + // TODO - CK does not support dq_accum + + at::Tensor dk_expanded, dv_expanded; + if (num_heads_k != num_heads) { // MQA / GQA + dk_expanded = torch::empty({total_k, num_heads, head_size_8x}, opts); + dv_expanded = torch::empty({total_k, num_heads, head_size_8x}, opts); + } else { + dk_expanded = dk; + dv_expanded = dv; + } + + if(zero_tensors) { + dq.zero_(); + dk_expanded.zero_(); + dv_expanded.zero_(); + softmax_d.zero_(); + } + + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + + uint64_t drop_seed = 1, drop_offset = 0; + int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size(); + + if (rng_state.has_value()) { + uint64_t* d = reinterpret_cast(rng_state.value().data_ptr()); + drop_seed = d[0]; + drop_offset = d[1]; + } else if(is_dropout) { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + auto philox_args = gen->philox_cuda_state(counter_offset); + std::tie(drop_seed, drop_offset) = flash::unpack(philox_args); + } + + if (max_seqlen_q > 0) { + ck_tile::stream_config stream_config{stream}; + dq.zero_(); // ck use atomic operation on dq + + auto traits = + get_ck_fmha_varlen_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value()); + + auto args = + get_ck_fmha_varlen_bwd_args( + mask, + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads, + num_heads_k, + head_size_8x, + q, + k, + v, + cu_seqlens_q, + 
cu_seqlens_k, + alibi_slopes_, + out, + softmax_lse, + dout_padded, + softmax_d, + dq, + dk_expanded, + dv_expanded, + softmax_scale, + p_dropout, + drop_seed, + drop_offset); + + fmha_bwd(traits, args, stream_config); + } else { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. + dk_expanded.zero_(); + dv_expanded.zero_(); + softmax_d.zero_(); + } + + // For MQA/GQA we need to sum dK and dV across the groups + if (num_heads_k != num_heads) { + at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {2}); + at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {2}); + } + if (head_size_og % 8 != 0) { + dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + } + + return { dq, dk, dv, softmax_d }; +} \ No newline at end of file diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp new file mode 100644 index 000000000..2d2f4cfef --- /dev/null +++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp @@ -0,0 +1,371 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include "flash_common.hpp" + +#include "fmha_fwd.hpp" +#include "mask.hpp" + +fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask, + std::string dtype, + int head_size, + bool has_dropout, + bool has_lse, + bool enable_alibi) +{ + return fmha_fwd_traits{head_size, + head_size, + dtype, + true, // is_group_mode + true, // is_v_rowmajor + mask.type, + enable_alibi ? bias_enum::alibi : bias_enum::no_bias, + has_lse, + has_dropout, + false}; // do_fp8_static_quant +} + +fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, + bool has_dropout_randval, + const mask_info &mask, + // sizes + const int b, + const int max_seqlen_q, + const int h, + const int h_k, + const int d, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + const at::Tensor seqlens_q, + const at::Tensor seqlens_k, + c10::optional &alibi_slopes_, + at::Tensor out, + at::Tensor softmax_lse, + at::Tensor dropout_randval, + float softmax_scale, + float p_dropout, + uint64_t drop_seed, + uint64_t drop_offset) +{ + // q: (total_q, nheads, d) + // k: (total_k, nheads_k, d) + // v: (total_k, nheads_k, d) + // o: (total_q, nheads, d) + + // alibi_slopes:(batch, nheads) or (nhead) + // lse: (batch, nheads, max_seqlen_q) + // randval: (nheads, total_q, max_seqlen_k) + + ck_tile::index_t total_q = q.size(0); + ck_tile::index_t total_k = k.size(0); + + ck_tile::index_t stride_q = q.stride(0); + ck_tile::index_t stride_k = k.stride(0); + ck_tile::index_t stride_v = v.stride(0); + ck_tile::index_t stride_o = out.stride(0); + ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0; + + ck_tile::index_t nhead_stride_q = q.stride(1); + ck_tile::index_t nhead_stride_k = k.stride(1); + ck_tile::index_t nhead_stride_v = v.stride(1); + ck_tile::index_t nhead_stride_o = out.stride(1); + ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(1) : 0; + ck_tile::index_t nhead_stride_randval = has_dropout_randval ? 
dropout_randval.stride(0) : 0; + + ck_tile::index_t batch_stride_q = 0; + ck_tile::index_t batch_stride_k = 0; + ck_tile::index_t batch_stride_v = 0; + ck_tile::index_t batch_stride_o = 0; + + ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0; + ck_tile::index_t batch_stride_randval = 0; + + void *alibi_slopes_ptr = nullptr; + ck_tile::index_t stride_alibi_slopes = 0; + + if (alibi_slopes_.has_value()) { + auto alibi_slopes = alibi_slopes_.value(); + CHECK_DEVICE(alibi_slopes); + TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); + TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h})); + alibi_slopes_ptr = alibi_slopes.data_ptr(); + stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0; + } + + return fmha_fwd_args{q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + alibi_slopes_ptr, // bias + has_dropout_randval ? dropout_randval.data_ptr() : nullptr, + nullptr, // lse_acc + nullptr, // o_acc + has_lse ? softmax_lse.data_ptr() : nullptr, + out.data_ptr(), + seqlens_q.data_ptr(), // seqstart_q + seqlens_k.data_ptr(), // seqstart_k + nullptr, // seqlen_kpads + total_q, + total_k, + b, + max_seqlen_q, + d, // hdim_q + d, // hdim_v + h, // nhead + h_k, // nhead_k + 1, // num_splits + softmax_scale, // scale_s + 1, // scale_p + 1, // scale_o + stride_q, + stride_k, + stride_v, + stride_alibi_slopes, + stride_randval, + 0, // stride_o_acc, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + 0, // nhead_stride_bias, FA without bias + nhead_stride_randval, + nhead_stride_lse, + 0, // nhead_stride_lse_acc + 0, // nhead_stride_o_acc + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + 0, // batch_stride_bias, FA without bias + batch_stride_randval, + batch_stride_lse, + 0, // batch_stride_lse_acc + 0, // batch_stride_o_acc + batch_stride_o, + 0, // split_stride_lse_acc + 0, // split_stride_o_acc + mask.left, + mask.right, + static_cast(mask.type), + p_dropout, + has_dropout_randval, + {drop_seed, drop_offset}}; +} + +std::vector +mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. 
+ c10::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + c10::optional & /*seqused_k*/, + c10::optional &/*leftpad_k_*/, // batch_size + c10::optional &block_table_, // batch_size x max_num_blocks_per_seq + c10::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, + const int max_seqlen_k, + const float p_dropout, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + int window_size_left, + int window_size_right, + const float /*softcap*/, + const bool return_dropout_randval, + c10::optional gen_) +{ + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); + + std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16"; + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(cu_seqlens_q); + CHECK_DEVICE(cu_seqlens_k); + + // TODO - Support paged_KV + const bool paged_KV = block_table_.has_value(); + TORCH_CHECK(!paged_KV, "CK does not support paged_KV yet"); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); + + const auto sizes = q.sizes(); + + const int batch_size = cu_seqlens_q.numel() - 1; + int num_heads = sizes[1]; + const int head_size_og = sizes[2]; + const int num_heads_k = k.size(1); + + const int max_num_blocks_per_seq = 0; + const int num_blocks = 0; + + if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case + + // TODO + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case + // H/t Daniel Haziza + + const int total_q = q.size(0); + const int total_k = k.size(0); + + TORCH_CHECK(batch_size > 0, "batch size must be postive"); + TORCH_CHECK(head_size_og <= 256, "CK only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (window_size_left >= max_seqlen_k) { window_size_left = -1; } + if (window_size_right >= max_seqlen_k) { window_size_right = -1; } + + mask_info mask; + + if (is_causal) { + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + window_size_right = 0; + std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0"; + mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // casual + } + else if (window_size_left == -1 && window_size_right == -1) { + mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask + } + else { + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
+ std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right); + mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local + } + + CHECK_SHAPE(q, total_q, num_heads, head_size_og); + CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + at::Tensor q_padded, k_padded, v_padded; + if (head_size_og % 8 != 0) { + q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } + else { + q_padded = q; + k_padded = k; + v_padded = v; + } + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, total_q, num_heads, head_size_og); + + if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } + } + else { + out = torch::empty_like(q_padded); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_8x = round_multiple(head_size_og, 8); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + bool has_lse = true; + bool has_dropout = p_dropout > 0.0f; + + at::Tensor softmax_lse; + // TODO - check gradient, only training require lse + softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(torch::kFloat32)); + + at::Tensor p; + if (return_dropout_randval) { + TORCH_CHECK(has_dropout, "return_dropout_randval require p_dropout > 0"); + p = torch::empty({num_heads, total_q, max_seqlen_k}, opts.dtype(torch::kUInt8)); + } + + if (zero_tensors) + { + out.zero_(); + softmax_lse.fill_(-std::numeric_limits::infinity()); + if (return_dropout_randval) {p.zero_();} + } + + uint64_t drop_seed = 1, drop_offset = 0; + int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size(); + auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); + + if (p_dropout > 0.0) { + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + auto philox_args = gen->philox_cuda_state(counter_offset); + std::tie(drop_seed, drop_offset) = flash::unpack(philox_args); + } + + rng_state[0] = *(reinterpret_cast(&drop_seed)); + rng_state[1] = *(reinterpret_cast(&drop_offset)); + + if (max_seqlen_k > 0) { + auto stream = at::cuda::getCurrentHIPStream().stream(); + ck_tile::stream_config stream_config{stream}; + + auto traits = + get_ck_fmha_varlen_fwd_traits(mask, q_dtype_str, head_size_8x, has_dropout, has_lse, alibi_slopes_.has_value()); + + auto args = + get_ck_fmha_varlen_fwd_args( + has_lse, + return_dropout_randval, + mask, + batch_size, + max_seqlen_q, + num_heads, + num_heads_k, + head_size_8x, + q_padded, + k_padded, + v_padded, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes_, + 
out, + softmax_lse, + p, + softmax_scale, + p_dropout, + drop_seed, + drop_offset); + + fmha_fwd(traits, args, stream_config); + } + else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + at::Tensor out_padded = out; + if (head_size_og % 8 != 0) { + out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + if (out_.has_value()) { out_.value().copy_(out); } + } + + return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state}; +} diff --git a/setup.py b/setup.py index 1f7ae655f..1051a1ff6 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,8 @@ import os import re import ast +import glob +import shutil from pathlib import Path from packaging.version import parse, Version import platform @@ -22,6 +24,8 @@ CppExtension, CUDAExtension, CUDA_HOME, + ROCM_HOME, + IS_HIP_EXTENSION, ) @@ -32,6 +36,19 @@ # ninja build does not work unless include_dirs are abs path this_dir = os.path.dirname(os.path.abspath(__file__)) +BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto") + +if BUILD_TARGET == "auto": + if IS_HIP_EXTENSION: + IS_ROCM = True + else: + IS_ROCM = False +else: + if BUILD_TARGET == "cuda": + IS_ROCM = False + elif BUILD_TARGET == "rocm": + IS_ROCM = True + PACKAGE_NAME = "flash_attn" BASE_WHEEL_URL = ( @@ -82,19 +99,47 @@ def check_if_cuda_home_none(global_option: str) -> None: ) +def check_if_rocm_home_none(global_option: str) -> None: + if ROCM_HOME is not None: + return + # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary + # in that case. + warnings.warn( + f"{global_option} was requested, but hipcc was not found." + ) + + def append_nvcc_threads(nvcc_extra_args): nvcc_threads = os.getenv("NVCC_THREADS") or "4" return nvcc_extra_args + ["--threads", nvcc_threads] +def rename_cpp_to_cu(cpp_files): + for entry in cpp_files: + shutil.copy(entry, os.path.splitext(entry)[0] + ".cu") + + +def validate_and_update_archs(archs): + # List of allowed architectures + allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942"] + + # Validate if each element in archs is in allowed_archs + assert all( + arch in allowed_archs for arch in archs + ), f"One of GPU archs of {archs} is invalid or not supported by Flash-Attention" + + cmdclass = {} ext_modules = [] # We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp # files included in the source distribution, in case the user compiles from source. 
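The helpers above feed the ROCm branch of the build that follows: `BUILD_TARGET` selects the backend and `GPU_ARCHS` restricts the offload targets handed to the compiler. A hedged sketch of how those two knobs combine, mirroring `validate_and_update_archs` and the `cc_flag` construction shown below (an illustration only, not an additional public API):

```python
import os

# Force the ROCm/CK path even if autodetection would pick CUDA, and restrict
# compilation to two MI200/MI300 architectures (semicolon-separated list).
os.environ["BUILD_TARGET"] = "rocm"
os.environ["GPU_ARCHS"] = "gfx90a;gfx942"

archs = os.environ.get("GPU_ARCHS", "native").split(";")
allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942"]
assert all(arch in allowed_archs for arch in archs), f"unsupported arch in {archs}"

# Each validated arch becomes an --offload-arch flag on the compiler command line.
cc_flag = [f"--offload-arch={arch}" for arch in archs]
print(cc_flag)  # ['--offload-arch=gfx90a', '--offload-arch=gfx942']
```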
-subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) +if IS_ROCM: + subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"]) +else: + subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) -if not SKIP_CUDA_BUILD: +if not SKIP_CUDA_BUILD and not IS_ROCM: print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) TORCH_MAJOR = int(torch.__version__.split(".")[0]) TORCH_MINOR = int(torch.__version__.split(".")[1]) @@ -250,6 +295,95 @@ def append_nvcc_threads(nvcc_extra_args): ], ) ) +elif not SKIP_CUDA_BUILD and IS_ROCM: + ck_dir = "csrc/composable_kernel" + + #use codegen get code dispatch + if not os.path.exists("./build"): + os.makedirs("build") + + os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd --output_dir build --receipt 2") + os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d bwd --output_dir build --receipt 2") + + print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) + + # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h + # See https://github.com/pytorch/pytorch/pull/70650 + generator_flag = [] + torch_dir = torch.__path__[0] + if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): + generator_flag = ["-DOLD_GENERATOR_PATH"] + + check_if_rocm_home_none("flash_attn") + cc_flag = [] + + archs = os.getenv("GPU_ARCHS", "native").split(";") + validate_and_update_archs(archs) + + cc_flag = [f"--offload-arch={arch}" for arch in archs] + + # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as + # torch._C._GLIBCXX_USE_CXX11_ABI + # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 + if FORCE_CXX11_ABI: + torch._C._GLIBCXX_USE_CXX11_ABI = True + + sources = ["csrc/flash_attn_ck/flash_api.cpp", + "csrc/flash_attn_ck/mha_bwd.cpp", + "csrc/flash_attn_ck/mha_fwd.cpp", + "csrc/flash_attn_ck/mha_varlen_bwd.cpp", + "csrc/flash_attn_ck/mha_varlen_fwd.cpp"] + glob.glob( + f"build/fmha_*wd*.cpp" + ) + + rename_cpp_to_cu(sources) + + renamed_sources = ["csrc/flash_attn_ck/flash_api.cu", + "csrc/flash_attn_ck/mha_bwd.cu", + "csrc/flash_attn_ck/mha_fwd.cu", + "csrc/flash_attn_ck/mha_varlen_bwd.cu", + "csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu") + extra_compile_args = { + "cxx": ["-O3", "-std=c++17"] + generator_flag, + "nvcc": + [ + "-O3","-std=c++17", + "-mllvm", "-enable-post-misched=0", + "-DCK_TILE_FMHA_FWD_FAST_EXP2=1", + "-fgpu-flush-denormals-to-zero", + "-DCK_ENABLE_BF16", + "-DCK_ENABLE_BF8", + "-DCK_ENABLE_FP16", + "-DCK_ENABLE_FP32", + "-DCK_ENABLE_FP64", + "-DCK_ENABLE_FP8", + "-DCK_ENABLE_INT8", + "-DCK_USE_XDL", + "-DUSE_PROF_API=1", + "-D__HIP_PLATFORM_HCC__=1", + # "-DFLASHATTENTION_DISABLE_BACKWARD", + ] + + generator_flag + + cc_flag + , + } + + include_dirs = [ + Path(this_dir) / "csrc" / "composable_kernel" / "include", + Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include", + Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha", + ] + + ext_modules.append( + CUDAExtension( + name="flash_attn_2_cuda", + sources=renamed_sources, + extra_compile_args=extra_compile_args, + include_dirs=include_dirs, + ) + ) def get_package_version(): @@ -264,25 +398,33 @@ def get_package_version(): def get_wheel_url(): - # 
Determine the version numbers that will be used to determine the correct wheel - # We're using the CUDA version used to build torch, not the one currently installed - # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) - torch_cuda_version = parse(torch.version.cuda) torch_version_raw = parse(torch.__version__) - # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3 - # to save CI time. Minor versions should be compatible. - torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3") python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" platform_name = get_platform() flash_version = get_package_version() - # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" - cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}" torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() - # Determine wheel URL based on CUDA version, torch version, python version and OS - wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" + if IS_ROCM: + torch_hip_version = parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+')) + hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}" + wheel_filename = f"{PACKAGE_NAME}-{flash_version}+rocm{hip_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" + else: + # Determine the version numbers that will be used to determine the correct wheel + # We're using the CUDA version used to build torch, not the one currently installed + # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) + torch_cuda_version = parse(torch.version.cuda) + # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3 + # to save CI time. Minor versions should be compatible. 
+ torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3") + # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" + cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}" + + # Determine wheel URL based on CUDA version, torch version, python version and OS + wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" + wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename) + return wheel_url, wheel_filename diff --git a/tests/test_flash_attn_ck.py b/tests/test_flash_attn_ck.py new file mode 100644 index 000000000..fbcb51cef --- /dev/null +++ b/tests/test_flash_attn_ck.py @@ -0,0 +1,754 @@ +import math + +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from flash_attn import ( + flash_attn_func, + flash_attn_kvpacked_func, + flash_attn_qkvpacked_func, + flash_attn_varlen_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_qkvpacked_func, +) + +from test_flash_attn import ( + attn_bias_from_alibi_slopes, + convert_flash_attn_S_to_softmax, + generate_qkv, + generate_random_padding_mask, + attention_ref, + attention_kvpacked_ref, + attention_qkvpacked_ref, +) + +def is_bwd_hdim_supported(d): + return d <= 128 and d % 2 == 0 + + +def ck_randval_to_dropout_mask(randval, p): + # If p = 0.3, randval in 255 * (0.7, 1.0] will be dropout + # randval in 255 * [0, 0.7] will be kept + # If return dropout_mask >=0, value will be kept + return torch.floor(255.0 * (1 - p) - randval) + + +def pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, seqlen_q_rounded, seqlen_k_rounded): + """ pad + rearrange [nheads, total_q, max_seqlen_k] into [b, nheads, seqlen_q_rounded, seqlen_k_rounded] + Arguments: + S_dmask: (nheads, total_q, max_seqlen_k) + cu_seqlens_q: (b + 1) + Output: + S_dmask: (b, nheads, seqlen_q_rounded, seqlen_k_rounded) + """ + batch_size = cu_seqlens_q.numel() - 1 + seqlens_q = torch.roll(cu_seqlens_q, shifts = -1) - cu_seqlens_q + seqlens_q = seqlens_q[0:batch_size].tolist() + S_dmask = torch.split(S_dmask, seqlens_q, dim=1) + # [(nheads, seqlen_q0, max_seqlen_k), (nheads, seqlen_q1, max_seqlen_k), ..., (nheads, seqlen_qb, max_seqlen_k)] + masks = () + for mask in S_dmask: + # (nheads, seqlen_qi, max_seqlen_k) -> (nheads, seqlen_q_rounded, seqlen_k_rounded) + mask = F.pad(mask, (0, seqlen_k_rounded - mask.shape[2], 0, seqlen_q_rounded - mask.shape[1], 0, 0)).unsqueeze(1) + masks = masks + (mask, ) + S_dmask = torch.cat(masks, dim=1) + + S_dmask = S_dmask.transpose(0, 1) + return S_dmask + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("alibi", [False, True]) +@pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +@pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048]) +@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) +def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype): + if d > 256: + pytest.skip() + + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 4 + nheads = 9 + window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,)) + + qkv = torch.randn( + batch_size, seqlen, 3, nheads, d, 
device=device, dtype=dtype, requires_grad=True + ) + + if alibi: + alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 + attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal) + else: + alibi_slopes, attn_bias = None, None + out, lse, S_dmask = flash_attn_qkvpacked_func( + qkv, + dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + if dropout_p > 0.0: + # TODO - move to c++ mha_varlen_fwd() + S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p) + S_dmask_converted = convert_flash_attn_S_to_softmax( + S_dmask, + seqlen, + seqlen, + None, + None, + d, + dropout_p > 0.0, + causal=causal, + window_size=window_size, + ) + dropout_mask = S_dmask_converted >= 0 + # CK does not return P. Hence, we don't test the attn here. + else: + dropout_mask = None + + out_ref, attn_ref = attention_qkvpacked_ref( + qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size + ) + out_pt, attn_pt = attention_qkvpacked_ref( + qkv, + None, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + upcast=False, + reorder_ops=True, + ) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + + g = torch.randn_like(out) + if is_bwd_hdim_supported(d): + (dqkv,) = torch.autograd.grad(out, qkv, g) + (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g) + (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g) + print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}") + print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}") + print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}") + print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}") + print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}") + print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}") + print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}") + + # TODO - use 10 times to check, wait for ck to change dq type to f32 + assert (dqkv - dqkv_ref).abs().max().item() <= 10 * (dqkv_pt - dqkv_ref).abs().max().item() + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("alibi", [False, True]) +@pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256]) +@pytest.mark.parametrize("seqlen", [97, 128, 200, 257, 384, 512, 768, 1025, 2048]) +@pytest.mark.parametrize("dropout_p", [0, 0.17]) +def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype): + if d > 256: + pytest.skip() + + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 5 + nheads = 6 + window_size = (-1, -1) if not local else 
torch.randint(0, seqlen, (2,)) + qkv = torch.randn( + batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True + ) + + key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random") + # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full') + if alibi: + alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 + attn_bias = attn_bias_from_alibi_slopes( + alibi_slopes, seqlen, seqlen, key_padding_mask, key_padding_mask, causal=causal + ) + else: + alibi_slopes, attn_bias = None, None + + qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv( + *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True + ) + + out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func( + qkv_unpad, + cu_seqlens, + max_seqlen, + dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + out = output_pad_fn(out_unpad) + if dropout_p > 0.0: + # TODO - move to c++ mha_varlen_fwd() + S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p) + S_dmask = pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens, seqlen, seqlen) + + S_dmask_converted = convert_flash_attn_S_to_softmax( + S_dmask, + seqlen, + seqlen, + key_padding_mask, + key_padding_mask, + d, + dropout_p > 0.0, + causal=causal, + window_size=window_size, + ) + + dropout_mask = S_dmask_converted >= 0 + # CK does not return P. Hence, we don't test the attn here. + else: + dropout_mask = None + + out_ref, attn_ref = attention_qkvpacked_ref( + qkv, + key_padding_mask, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + ) + out_pt, attn_pt = attention_qkvpacked_ref( + qkv, + key_padding_mask, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + upcast=False, + reorder_ops=True, + ) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. 
+ assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + + g = torch.randn_like(out) + if is_bwd_hdim_supported(d): + (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g) + dqkv = dqkv_pad_fn(dqkv_unpad) + (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g) + (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g) + print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}") + print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}") + print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}") + print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}") + print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}") + print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}") + print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}") + + # TODO - use 10 times to check, wait for ck to change dq type to f32 + assert (dqkv - dqkv_ref).abs().max().item() <= 10 * (dqkv_pt - dqkv_ref).abs().max().item() + + +@pytest.mark.parametrize("kvpacked", [True, False]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("alibi", [False, True]) +@pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ], +) +@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) +def test_flash_attn_output( + seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked +): + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 4 + nheads = 9 + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + assert nheads % nheads_k == 0 + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + if kvpacked: + kv = torch.randn( + batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True + ) + else: + k = torch.randn( + batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True + ) + v = torch.randn( + batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True + ) + if alibi: + alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 + attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal) + else: + alibi_slopes, attn_bias = None, None + + if kvpacked: + out, lse, S_dmask = flash_attn_kvpacked_func( + q, + kv, + dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + else: + out, lse, S_dmask = flash_attn_func( + q, + k, + v, + dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + if dropout_p > 0.0: + # TODO - move to c++ mha_varlen_fwd() + S_dmask = ck_randval_to_dropout_mask(S_dmask, 
dropout_p) + S_dmask_converted = convert_flash_attn_S_to_softmax( + S_dmask, + seqlen_q, + seqlen_k, + None, + None, + d, + dropout_p > 0.0, + causal=causal, + window_size=window_size, + ) + dropout_mask = S_dmask_converted >= 0 + if kvpacked: + kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k) + k_rep, v_rep = kv_rep.unbind(dim=2) + else: + k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k) + v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k) + # CK does not return P. Hence, we don't test the attn here. + else: + dropout_mask = None + + if kvpacked: + out_ref, attn_ref = attention_kvpacked_ref( + q, + kv, + None, + None, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + ) + out_pt, attn_pt = attention_kvpacked_ref( + q, + kv, + None, + None, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + upcast=False, + reorder_ops=True, + ) + else: + out_ref, attn_ref = attention_ref( + q, + k, + v, + None, + None, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + ) + out_pt, attn_pt = attention_ref( + q, + k, + v, + None, + None, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + upcast=False, + reorder_ops=True, + ) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + + g = torch.randn_like(out) + if is_bwd_hdim_supported(d): + if kvpacked: + ( + dq, + dkv, + ) = torch.autograd.grad(out, (q, kv), g) + dk, dv = dkv.unbind(2) + ( + dq_ref, + dkv_ref, + ) = torch.autograd.grad(out_ref, (q, kv), g) + dk_ref, dv_ref = dkv_ref.unbind(2) + ( + dq_pt, + dkv_pt, + ) = torch.autograd.grad(out_pt, (q, kv), g) + dk_pt, dv_pt = dkv_pt.unbind(2) + else: + ( + dq, + dk, + dv, + ) = torch.autograd.grad(out, (q, k, v), g) + ( + dq_ref, + dk_ref, + dv_ref, + ) = torch.autograd.grad(out_ref, (q, k, v), g) + ( + dq_pt, + dk_pt, + dv_pt, + ) = torch.autograd.grad(out_pt, (q, k, v), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + + # TODO - use 10 times to check, wait for ck to change dq type to f32 + assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item() + assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item() + assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - 
dv_ref).abs().max().item() + + +@pytest.mark.parametrize("kvpacked", [True, False]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("alibi", [False, True]) +@pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 147), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ], +) +@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) +def test_flash_attn_varlen_output( + seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked +): + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 4 + nheads = 9 + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + assert nheads % nheads_k == 0 + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + if kvpacked: + kv = torch.randn( + batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True + ) + else: + k = torch.randn( + batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True + ) + v = torch.randn( + batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True + ) + + query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") + key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random") + # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full') + if alibi: + alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 + attn_bias = attn_bias_from_alibi_slopes( + alibi_slopes, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, causal=causal + ) + else: + alibi_slopes, attn_bias = None, None + + if kvpacked: + ( + q_unpad, + kv_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + kv, + output_pad_fn, + dq_pad_fn, + dkv_pad_fn, + ) = generate_qkv(q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True) + out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func( + q_unpad, + kv_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + else: + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + out_unpad, sm_lse, S_dmask = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + out = output_pad_fn(out_unpad) + if dropout_p > 0.0: + # TODO - move to c++ mha_varlen_fwd() + S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p) + S_dmask = pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, 
seqlen_q, seqlen_k) + S_dmask_converted = convert_flash_attn_S_to_softmax( + S_dmask, + seqlen_q, + seqlen_k, + query_padding_mask, + key_padding_mask, + d, + dropout_p > 0.0, + causal=causal, + window_size=window_size, + ) + dropout_mask = S_dmask_converted >= 0 + if kvpacked: + kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k) + k_rep, v_rep = kv_rep.unbind(dim=2) + else: + k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k) + v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k) + # CK does not return P. Hence, we don't test the attn here. + else: + dropout_mask = None + + if kvpacked: + out_ref, attn_ref = attention_kvpacked_ref( + q, + kv, + query_padding_mask, + key_padding_mask, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + ) + out_pt, attn_pt = attention_kvpacked_ref( + q, + kv, + query_padding_mask, + key_padding_mask, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + upcast=False, + reorder_ops=True, + ) + else: + out_ref, attn_ref = attention_ref( + q, + k, + v, + query_padding_mask, + key_padding_mask, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + ) + out_pt, attn_pt = attention_ref( + q, + k, + v, + query_padding_mask, + key_padding_mask, + attn_bias, + dropout_p, + dropout_mask, + causal=causal, + window_size=window_size, + upcast=False, + reorder_ops=True, + ) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most 4 times the numerical error + # of a Pytorch implementation. 
+ assert (out - out_ref).abs().max().item() <= 4 * (out_pt - out_ref).abs().max().item() + + g = torch.randn_like(out) + if is_bwd_hdim_supported(d): + if kvpacked: + ( + dq_unpad, + dkv_unpad, + ) = torch.autograd.grad(out, (q_unpad, kv_unpad), g) + dk, dv = dkv_pad_fn(dkv_unpad).unbind(2) + ( + dq_ref, + dkv_ref, + ) = torch.autograd.grad(out_ref, (q, kv), g) + dk_ref, dv_ref = dkv_ref.unbind(2) + ( + dq_pt, + dkv_pt, + ) = torch.autograd.grad(out_pt, (q, kv), g) + dk_pt, dv_pt = dkv_pt.unbind(2) + else: + ( + dq_unpad, + dk_unpad, + dv_unpad, + ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g) + dk = dk_pad_fn(dk_unpad) + dv = dk_pad_fn(dv_unpad) + ( + dq_ref, + dk_ref, + dv_ref, + ) = torch.autograd.grad(out_ref, (q, k, v), g) + ( + dq_pt, + dk_pt, + dv_pt, + ) = torch.autograd.grad(out_pt, (q, k, v), g) + dq = dq_pad_fn(dq_unpad) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + + # TODO - use 10 times to check, wait for ck to change dq type to f32 + assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item() + assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item() + assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item()
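A closing note on the dropout checks used throughout these tests: CK reports the per-element dropout draws as `uint8` values in `[0, 255]`, and `ck_randval_to_dropout_mask` maps kept elements to non-negative values so that `dropout_mask = S_dmask_converted >= 0` recovers the boolean keep-mask. A small numeric illustration of that convention (a standalone sketch, not part of the test suite):

```python
import torch

def ck_keep_mask(randval_u8, p):
    # Mirrors ck_randval_to_dropout_mask: a position is kept when its draw is
    # at or below 255 * (1 - p); kept positions map to values >= 0.
    return torch.floor(255.0 * (1 - p) - randval_u8.float()) >= 0

randval = torch.tensor([0, 100, 178, 179, 255], dtype=torch.uint8)
print(ck_keep_mask(randval, p=0.3))
# tensor([True, True, True, False, False]); draws at or below 255 * (1 - p) = 178.5 are kept
```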