forked from Dao-AILab/flash-attention
Support AMD ROCm on FlashAttention 2 (Dao-AILab#1010)
* Support CK in fmha
* Add CK submodule
* Do not return lse if return_softmax == false
* Use receipt to speed up CK compile time
* Integrate new version of ck_tile
* Support dropout for mha_fwd()
* Add dropout to mha_varlen_fwd()
* Update CK to develop
* Extract padding function for dropout randval
* Extract randval transformation function
* Sync the code structure and coding style with FA
* Remove this line; the C++ API will handle this. Sync with test_flash_attn.py
* Fix compile error
* Add mha_bwd
* Generate dropout seed and offset from the user generator
* Update CK
* Add mha_varlen_bwd
* Use the same Python that builds flash-attn to generate the CK kernels
* Fix bug in group-mode fwd when returning softmax lse
* Loosen the test tolerance
* Add test_flash_attn_output() and test_flash_attn_varlen_output()
* Always fill softmax_lse
* Remove duplicate benchmark script, since mha_bwd is already implemented
* Refine getting values from tuples
* Use default parameter for stream_config
* Unblock all platforms
* Add comment
* Refine the test code
* Refine naming
* Add unpack to namespace
* Do not hardcode the warp size of 64 (see the sketch below)
* Add more targets
* Add README
* Optimize mha_fwd if seqlen_q == 1
* Support get_wheel_url for ROCm
* Detect the ROCm environment via PyTorch's IS_HIP_EXTENSION
* Update to latest CK
* Add necessary compile flags
* Sync the API with upstream FA

Co-authored-by: carlushuang <[email protected]>
Co-authored-by: Yichen Yan <[email protected]>
Co-authored-by: Po Yen Chen <[email protected]>
Co-authored-by: Yichen Yan <[email protected]>
1 parent dfe1a59 · commit d8f104e
Showing 11 changed files with 2,581 additions and 13 deletions.
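One portability point from the change list above ("Do not hardcode the warp size of 64"): AMD CDNA GPUs execute 64-wide wavefronts while NVIDIA GPUs use 32-wide warps, so a width of 64 cannot be assumed. The following is an illustrative sketch of querying the width at runtime through PyTorch's device properties; the helper is not part of this commit:

#include <ATen/cuda/CUDAContext.h>  // hipified to the ROCm equivalents when built as a HIP extension

// Sketch only: ask the runtime for the warp/wavefront width instead of hardcoding 64.
static int current_warp_size() {
    // cudaDeviceProp (hipDeviceProp_t under ROCm) reports warpSize: 32 on NVIDIA, 64 on AMD CDNA.
    return at::cuda::getCurrentDeviceProperties()->warpSize;
}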
.gitmodules
@@ -1,3 +1,6 @@
[submodule "csrc/cutlass"] | ||
path = csrc/cutlass | ||
url = https://github.com/NVIDIA/cutlass.git | ||
[submodule "csrc/composable_kernel"] | ||
path = csrc/composable_kernel | ||
url = https://github.com/ROCm/composable_kernel.git |
Submodule composable_kernel added at 818297
@@ -0,0 +1,99 @@
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#include "flash_common.hpp"

std::vector<at::Tensor>
mha_fwd(at::Tensor &q,
        const at::Tensor &k,
        const at::Tensor &v,
        c10::optional<at::Tensor> &out_,
        c10::optional<at::Tensor> &alibi_slopes_,
        const float p_dropout,
        const float softmax_scale,
        bool is_causal,
        int window_size_left,
        int window_size_right,
        const float softcap,
        const bool return_softmax,
        c10::optional<at::Generator> gen_);

std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               c10::optional<at::Tensor> &out_,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,  // b+1
               const at::Tensor &cu_seqlens_k,  // b+1
               c10::optional<at::Tensor> &seqused_k,  // b. If given, only this many elements of each batch element's keys are used.
               c10::optional<const at::Tensor> &leftpad_k_,  // batch_size
               c10::optional<at::Tensor> &block_table_,  // batch_size x max_num_blocks_per_seq
               c10::optional<at::Tensor> &alibi_slopes_,  // num_heads or b x num_heads
               int max_seqlen_q,
               const int max_seqlen_k,
               const float p_dropout,
               const float softmax_scale,
               const bool zero_tensors,
               bool is_causal,
               int window_size_left,
               int window_size_right,
               const float softcap,
               const bool return_softmax,
               c10::optional<at::Generator> gen_);

std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads x head_size_og
        const at::Tensor &q,  // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,  // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,  // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &out,  // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &softmax_lse,  // b x h x seqlen_q
        c10::optional<at::Tensor> &dq_,  // batch_size x seqlen_q x num_heads x head_size
        c10::optional<at::Tensor> &dk_,  // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &dv_,  // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &alibi_slopes_,  // num_heads or batch_size x num_heads
        const float p_dropout,  // probability to drop
        const float softmax_scale,
        const bool is_causal,
        int window_size_left,
        int window_size_right,
        const float softcap,
        const bool deterministic,
        c10::optional<at::Generator> gen_,
        c10::optional<at::Tensor> &rng_state);

std::vector<at::Tensor>
mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads x head_size
               const at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &out,  // total_q x num_heads x head_size
               const at::Tensor &softmax_lse,  // b x h x s softmax logsumexp
               c10::optional<at::Tensor> &dq_,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               c10::optional<at::Tensor> &dk_,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               c10::optional<at::Tensor> &dv_,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,  // b+1
               const at::Tensor &cu_seqlens_k,  // b+1
               c10::optional<at::Tensor> &alibi_slopes_,  // num_heads or b x num_heads
               const int max_seqlen_q,
               const int max_seqlen_k,  // max sequence length to choose the kernel
               const float p_dropout,  // probability to drop
               const float softmax_scale,
               const bool zero_tensors,
               const bool is_causal,
               int window_size_left,
               int window_size_right,
               const float softcap,
               const bool deterministic,
               c10::optional<at::Generator> gen_,
               c10::optional<at::Tensor> &rng_state);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.doc() = "FlashAttention";
    m.def("fwd", &mha_fwd, "Forward pass");
    m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
    m.def("bwd", &mha_bwd, "Backward pass");
    m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
}
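The varlen entry points above take q, k, and v packed along the first dimension (all sequences concatenated), with cu_seqlens_q and cu_seqlens_k holding the b+1 cumulative sequence lengths that delimit each batch element. The following is a rough sketch of how a caller could assemble such an offsets tensor before calling the varlen bindings; build_cu_seqlens is a made-up helper for illustration, not part of this commit:

#include <torch/torch.h>
#include <vector>

// Sketch: cu_seqlens is the exclusive prefix sum of per-sequence lengths.
// For lengths {3, 5, 2} it holds {0, 3, 8, 10}; the last entry equals total_q (or total_k).
static at::Tensor build_cu_seqlens(const std::vector<int32_t> &seqlens) {
    const int64_t b = static_cast<int64_t>(seqlens.size());
    auto cu = torch::zeros({b + 1}, torch::dtype(torch::kInt32));
    auto acc = cu.accessor<int32_t, 1>();
    for (int64_t i = 0; i < b; ++i) {
        acc[i + 1] = acc[i] + seqlens[i];
    }
    return cu;  // move to the GPU (e.g. .cuda()) before passing it to the varlen kernels
}

Upstream FlashAttention expects these offset tensors as int32 on the same device as q/k/v; the exact dtype and device requirements are checked inside the C++ implementation.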
@@ -0,0 +1,38 @@
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

namespace flash {
// Copy from PyTorch
// https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
static std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
    if (arg.captured_) {
        // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
        // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
        // For most threads' reads it will hit in cache, so it shouldn't hurt performance.
        return std::make_tuple(static_cast<uint64_t>(*arg.seed_.ptr), static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
    } else {
        return std::make_tuple(arg.seed_.val, arg.offset_.val);
    }
}

} // namespace flash
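flash::unpack above mirrors PyTorch's UnpackRaw helper: it turns an at::PhiloxCudaState into a (seed, offset) pair whether the state was captured inside a CUDA/HIP graph (pointer form) or produced eagerly (value form). Below is a rough sketch of the usual dropout pattern that feeds it, advancing the optional user generator before unpacking; the helper name and counter_offset argument are illustrative placeholders, not code from this commit:

#include <mutex>
#include <tuple>

// Sketch only: advance the (optional) user generator and unpack the Philox seed/offset
// that the dropout path consumes.
static std::tuple<uint64_t, uint64_t>
make_dropout_seed_offset(c10::optional<at::Generator> gen_, uint64_t counter_offset) {
    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());
    at::PhiloxCudaState philox_state;
    {
        // Advance the generator under its mutex so concurrent launches get distinct offsets.
        std::lock_guard<std::mutex> lock(gen->mutex_);
        philox_state = gen->philox_cuda_state(counter_offset);
    }
    // flash::unpack handles both graph-captured (pointer) and eager (value) states.
    return flash::unpack(philox_state);
}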