
Merge pull request #1442 from houseroad/replace_c10_optional
Replace c10::optional with std::optional in flash_attn
ipiszy authored Jan 14, 2025
2 parents 22c0358 + 74aed78 commit 0fcd405
Showing 11 changed files with 166 additions and 166 deletions.
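The change is a mechanical rename: every c10::optional<T> parameter in the flash_attn API surface becomes std::optional<T>, with no behavioral change. The sketch below illustrates the before/after shape of such a parameter and a caller passing an empty value. It is a minimal illustration, not code from this diff: the helper name set_bias_sketch and its body are hypothetical, and it assumes a recent PyTorch where std::optional<at::Tensor> is accepted wherever c10::optional<at::Tensor> was.

    // Minimal sketch of the parameter-type change applied throughout this commit.
    // set_bias_sketch is a hypothetical helper; the real entry points are
    // mha_fwd, mha_varlen_fwd, mha_bwd, mha_varlen_bwd, and mha_fwd_kvcache.
    #include <optional>
    #include <ATen/ATen.h>

    // Before: void set_bias_sketch(c10::optional<at::Tensor> &bias_);
    // After:
    void set_bias_sketch(std::optional<at::Tensor> &bias_) {
        if (bias_.has_value()) {
            // A tensor was supplied; dereference the optional to use it.
            const at::Tensor &bias = *bias_;
            TORCH_CHECK(bias.is_contiguous(), "bias must be contiguous");
        }
        // Empty optional: treat as "no bias", mirroring the nullptr paths
        // in the real code.
    }

    // Caller side: the parameters are non-const lvalue references, so an
    // empty optional must be passed as a named variable, not a temporary.
    void example_caller() {
        std::optional<at::Tensor> no_bias;  // empty, i.e. std::nullopt
        set_bias_sketch(no_bias);
    }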
64 changes: 32 additions & 32 deletions csrc/flash_attn/flash_api.cpp
@@ -326,7 +326,7 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(Flash_fwd_params &params,
return std::make_tuple(softmax_lse_accum, out_accum);
}

-void set_params_alibi(Flash_fwd_params &params, c10::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads){
+void set_params_alibi(Flash_fwd_params &params, std::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads){
#ifdef FLASHATTENTION_DISABLE_ALIBI
TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi.");
params.alibi_slopes_ptr = nullptr;
@@ -349,16 +349,16 @@ std::vector<at::Tensor>
mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
-c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
-c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
+std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
+std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool return_softmax,
-c10::optional<at::Generator> gen_) {
+std::optional<at::Generator> gen_) {

// Otherwise the kernel will be launched from cuda:0 device
at::cuda::CUDAGuard device_guard{q.device()};
@@ -513,13 +513,13 @@ std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
-c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
+std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
-c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
-c10::optional<const at::Tensor> &leftpad_k_, // batch_size
-c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
-c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
+std::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
+std::optional<const at::Tensor> &leftpad_k_, // batch_size
+std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
+std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
@@ -530,7 +530,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
int window_size_right,
const float softcap,
const bool return_softmax,
-c10::optional<at::Generator> gen_) {
+std::optional<at::Generator> gen_) {

// Otherwise the kernel will be launched from cuda:0 device
at::cuda::CUDAGuard device_guard{q.device()};
@@ -769,19 +769,19 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &softmax_lse, // b x h x seqlen_q
-c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
-c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
-c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
-c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
+std::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
+std::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
+std::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
+std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, // probability to drop
const float softmax_scale,
const bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool deterministic,
-c10::optional<at::Generator> gen_,
-c10::optional<at::Tensor> &rng_state) {
+std::optional<at::Generator> gen_,
+std::optional<at::Tensor> &rng_state) {

#ifdef FLASHATTENTION_DISABLE_BACKWARD
TORCH_CHECK(false, "This flash attention build does not support backward.");
@@ -975,12 +975,12 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &out, // total_q x num_heads x head_size
const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp
-c10::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
-c10::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
-c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+std::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+std::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+std::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
-c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
+std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
const int max_seqlen_q,
const int max_seqlen_k, // max sequence length to choose the kernel
const float p_dropout, // probability to drop
@@ -991,8 +991,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
int window_size_right,
const float softcap,
const bool deterministic,
-c10::optional<at::Generator> gen_,
-c10::optional<at::Tensor> &rng_state) {
+std::optional<at::Generator> gen_,
+std::optional<at::Tensor> &rng_state) {

#ifdef FLASHATTENTION_DISABLE_BACKWARD
TORCH_CHECK(false, "This flash attention build does not support backward.");
@@ -1201,16 +1201,16 @@ std::vector<at::Tensor>
mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
-c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
-c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
-c10::optional<const at::Tensor> &seqlens_k_, // batch_size
-c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
-c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
-c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
-c10::optional<const at::Tensor> &leftpad_k_, // batch_size
-c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
-c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
-c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
+std::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
+std::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
+std::optional<const at::Tensor> &seqlens_k_, // batch_size
+std::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
+std::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
+std::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
+std::optional<const at::Tensor> &leftpad_k_, // batch_size
+std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
+std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
+std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const float softmax_scale,
bool is_causal,
int window_size_left,
62 changes: 31 additions & 31 deletions csrc/flash_attn_ck/flash_api.cpp
@@ -8,28 +8,28 @@ std::vector<at::Tensor>
mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
-c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
-c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
+std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
+std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool return_softmax,
-c10::optional<at::Generator> gen_);
+std::optional<at::Generator> gen_);

std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
-c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
+std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
-c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
-c10::optional<const at::Tensor> &leftpad_k_, // batch_size
-c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
-c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
+std::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
+std::optional<const at::Tensor> &leftpad_k_, // batch_size
+std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
+std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
@@ -40,7 +40,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_hea
int window_size_right,
const float softcap,
const bool return_softmax,
-c10::optional<at::Generator> gen_);
+std::optional<at::Generator> gen_);

std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
@@ -49,19 +49,19 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &softmax_lse, // b x h x seqlen_q
-c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
-c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
-c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
-c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
+std::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
+std::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
+std::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
+std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, // probability to drop
const float softmax_scale,
const bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool deterministic,
-c10::optional<at::Generator> gen_,
-c10::optional<at::Tensor> &rng_state);
+std::optional<at::Generator> gen_,
+std::optional<at::Tensor> &rng_state);

std::vector<at::Tensor>
mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads x head_size
@@ -70,12 +70,12 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &out, // total_q x num_heads x head_size
const at::Tensor &softmax_lse, // b x h x s softmax logsumexp
-c10::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
-c10::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
-c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+std::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+std::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+std::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
-c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
+std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
const int max_seqlen_q,
const int max_seqlen_k, // max sequence length to choose the kernel
const float p_dropout, // probability to drop
@@ -86,23 +86,23 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
int window_size_right,
const float softcap,
const bool deterministic,
-c10::optional<at::Generator> gen_,
-c10::optional<at::Tensor> &rng_state);
+std::optional<at::Generator> gen_,
+std::optional<at::Tensor> &rng_state);

std::vector<at::Tensor>
mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
-c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
-c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
-c10::optional<const at::Tensor> &seqlens_k_, // batch_size
-c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
-c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
-c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
-c10::optional<const at::Tensor> &leftpad_k_, // batch_size
-c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
-c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
-c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
+std::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
+std::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
+std::optional<const at::Tensor> &seqlens_k_, // batch_size
+std::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
+std::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
+std::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
+std::optional<const at::Tensor> &leftpad_k_, // batch_size
+std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
+std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
+std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const float softmax_scale,
bool is_causal,
int window_size_left,
(Diffs for the remaining 9 changed files are not shown on this page.)
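The rename is source-compatible for callers: recent PyTorch releases define c10::optional as an alias of std::optional (an assumption based on upstream PyTorch headers, not something stated in this diff), so code that still spells the old type continues to compile against the new signatures. A hedged compatibility sketch:

    // Compatibility sketch, assuming c10::optional<T> is an alias of
    // std::optional<T> in the PyTorch version being built against.
    // takes_std is a hypothetical function standing in for the new signatures.
    #include <optional>
    #include <ATen/ATen.h>

    inline void takes_std(std::optional<at::Tensor> &maybe_tensor) {
        // New-style parameter; only a presence check for the sketch.
        if (maybe_tensor.has_value()) { /* use *maybe_tensor */ }
    }

    void old_style_caller() {
        c10::optional<at::Tensor> maybe;  // old spelling, empty by default
        takes_std(maybe);  // compiles when both spellings name the same type
    }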
