Skip to content

Commit

Permalink
Support AMD ROCm on FlashAttention 2 (Dao-AILab#1010)
Browse files Browse the repository at this point in the history
* Support ck in fmha

* Add ck submodule

* Do not return lse if return_softmax == false

* Use receipt to speed up ck compile time

* Integrate new version of ck_tile

* Support dropout for mha_fwd()

* Add dropout to mha_varlen_fwd()

* Update ck to develop

* Extract padding function for dropout randval

* Extract randval transformation function

* Sync the code structure and coding style with FA

* Remove this line, c++ api will handle this.
Sync with test_flash_attn.py

* fix compile error

* Add mha_bwd

* Generate dropout seed and offset from user generator

* update CK

* Add mha_varlen_bwd

* Use same python as build flash-attn to generate ck kernel

* Fix bug of group mode fwd about returning softmax lse

* larger the test tollerance

* Add test_flash_attn_output() and test_flash_attn_varlen_output()

* Always fill softmax_lse

* Remove duplicate benchmark script, since we already implement mha_bwd

* Refine get value from tuple

* Use default parameter for stream_config

* unblock all platform

* Add comment

* refine the test code

* Refine naming

* Add unpack to namespace

* Do not hardcode the warp size 64

* Add more targets

* Add README

* Optimize mha_fwd if seqlen_q == 1

* Support get_wheel_url for rocm

* Detect rocm environment by pytorch's IS_HIP_EXTENSION

* update to lastest ck

* Add necessary compile flag

* Sync the api with upstream FA

---------

Co-authored-by: carlushuang <[email protected]>
Co-authored-by: Yichen Yan <[email protected]>
Co-authored-by: Po Yen Chen <[email protected]>
Co-authored-by: Yichen Yan <[email protected]>
  • Loading branch information
5 people authored Jul 23, 2024
1 parent dfe1a59 commit d8f104e
Show file tree
Hide file tree
Showing 11 changed files with 2,581 additions and 13 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[submodule "csrc/cutlass"]
path = csrc/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "csrc/composable_kernel"]
path = csrc/composable_kernel
url = https://github.com/ROCm/composable_kernel.git
27 changes: 27 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,33 @@ This new release of FlashAttention-2 has been tested on several GPT-style
models, mostly on A100 GPUs.

If you encounter bugs, please open a GitHub Issue!
## AMD GPU/ROCm Support
ROCm version use [composable_kernel](https://github.com/ROCm/composable_kernel) as backend. It provides the implementation of FlashAttention-2.

## Installation and features
Requirements:
- ROCm 6.0+
- PyTorch 1.12.1+

We recommend the
[Pytorch](https://hub.docker.com/r/rocm/pytorch)
container from ROCm, which has all the required tools to install FlashAttention.

To compile from source:
```sh
python setup.py install
```

FlashAttention-2 on ROCm currently supports:
1. MI200 or MI300 GPUs.
2. Datatype fp16 and bf16
3. Forward's head dimensions up to 256. Backward head dimensions up to 128.

## Tests
To run the tests:
```sh
pytest tests/test_flash_attn_ck.py
```

## Citation
If you use this codebase, or otherwise found our work valuable, please cite:
Expand Down
1 change: 1 addition & 0 deletions csrc/composable_kernel
Submodule composable_kernel added at 818297
99 changes: 99 additions & 0 deletions csrc/flash_attn_ck/flash_api.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/

#include "flash_common.hpp"

std::vector<at::Tensor>
mha_fwd(at::Tensor &q,
const at::Tensor &k,
const at::Tensor &v,
c10::optional<at::Tensor> &out_,
c10::optional<at::Tensor> &alibi_slopes_,
const float p_dropout,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool return_softmax,
c10::optional<at::Generator> gen_);

std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
c10::optional<const at::Tensor> &leftpad_k_, // batch_size
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool return_softmax,
c10::optional<at::Generator> gen_);

std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &softmax_lse, // b x h x seqlen_q
c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, // probability to drop
const float softmax_scale,
const bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool deterministic,
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state);

std::vector<at::Tensor>
mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads x head_size
const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &out, // total_q x num_heads x head_size
const at::Tensor &softmax_lse, // b x h x s softmax logsumexp
c10::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
c10::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
const int max_seqlen_q,
const int max_seqlen_k, // max sequence length to choose the kernel
const float p_dropout, // probability to drop
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool deterministic,
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.doc() = "FlashAttention";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
m.def("bwd", &mha_bwd, "Backward pass");
m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
}
38 changes: 38 additions & 0 deletions csrc/flash_attn_ck/flash_common.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/

#pragma once

// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif


#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

namespace flash {
// Copy from PyTorch
// https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
static std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
if (arg.captured_) {
// static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
// *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
// For most threads' reads it will hit in cache, so it shouldn't hurt performance.
return std::make_tuple(static_cast<uint64_t>(*arg.seed_.ptr), static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
} else {
return std::make_tuple(arg.seed_.val, arg.offset_.val);
}
}

} // namespace flash
Loading

0 comments on commit d8f104e

Please sign in to comment.