4 changes: 4 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -11,8 +11,10 @@
from flag_gems.ops.argmax import argmax
from flag_gems.ops.argmin import argmin
from flag_gems.ops.attention import (
cascade_attention,
flash_attention_forward,
flash_attn_varlen_func,
merge_attn_states,
scaled_dot_product_attention,
)
from flag_gems.ops.batch_norm import batch_norm, batch_norm_backward
@@ -289,6 +291,8 @@
"fill_tensor_",
"flash_attention_forward",
"flash_attn_varlen_func",
"cascade_attention",
"merge_attn_states",
"flip",
"floor_divide",
"floor_divide_",
221 changes: 221 additions & 0 deletions src/flag_gems/ops/attention.py
@@ -1,5 +1,6 @@
import logging
from functools import partial
from typing import Optional

import torch
import triton
@@ -661,3 +662,223 @@ def flash_attn_varlen_func(
)

return (out, softmax_lse) if return_softmax_lse else out


@triton.jit
def merge_attn_states_kernel(
output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
output_lse, # [NUM_HEADS, NUM_TOKENS]
prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
prefix_lse, # [NUM_HEADS, NUM_TOKENS]
suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
suffix_lse, # [NUM_HEADS, NUM_TOKENS]
HEAD_SIZE: tl.constexpr,
PADDED_HEAD_SIZE: tl.constexpr,
OUTPUT_LSE: tl.constexpr,
):
token_idx = tl.program_id(0)
num_tokens = tl.num_programs(0)
head_idx = tl.program_id(1)
num_heads = tl.num_programs(1)

p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)

# FA2 and FA3 behave differently when the sum-exp is 0, which mainly
# arises with zero-length seqlens. FA3 returns -inf here while FA2 returns inf.
# If we see an inf, assume FA2 and convert inf to -inf for consistency
# and correctness. Inf generally doesn't make sense in this context outside
# of the undefined-behavior/FA2 case, so this is a safe assumption.
p_lse = float("-inf") if p_lse == float("inf") else p_lse
s_lse = float("-inf") if s_lse == float("inf") else s_lse

max_lse = tl.maximum(p_lse, s_lse)
p_lse = p_lse - max_lse
s_lse = s_lse - max_lse
out_se = tl.exp(p_lse) + tl.exp(s_lse)

if OUTPUT_LSE:
out_lse = tl.log(out_se) + max_lse
tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse)

head_arange = tl.arange(0, PADDED_HEAD_SIZE)
head_mask = head_arange < HEAD_SIZE
p_out = tl.load(
prefix_output
+ token_idx * num_heads * HEAD_SIZE
+ head_idx * HEAD_SIZE
+ head_arange,
mask=head_mask,
)
s_out = tl.load(
suffix_output
+ token_idx * num_heads * HEAD_SIZE
+ head_idx * HEAD_SIZE
+ head_arange,
mask=head_mask,
)

# NOTE(woosuk): Be careful with the numerical stability.
# We should compute the scale first, and then multiply it with the output.
# Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
p_scale = tl.exp(p_lse) / out_se
s_scale = tl.exp(s_lse) / out_se
out = p_out * p_scale + s_out * s_scale
tl.store(
output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
out,
mask=head_mask,
)


# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
# can be used to combine partial attention results (in the split-KV case)
def merge_attn_states(
output: torch.Tensor,
prefix_output: torch.Tensor,
prefix_lse: torch.Tensor,
suffix_output: torch.Tensor,
suffix_lse: torch.Tensor,
output_lse: Optional[torch.Tensor] = None,
) -> None:
num_tokens = output.shape[0]
num_query_heads = output.shape[1]
head_size = output.shape[2]
padded_head_size = triton.next_power_of_2(head_size)

# TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
merge_attn_states_kernel[(num_tokens, num_query_heads)](
output,
output_lse,
prefix_output,
prefix_lse,
suffix_output,
suffix_lse,
head_size,
padded_head_size,
output_lse is not None,
)
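
For readers comparing against the Triton kernel, here is a minimal pure-PyTorch sketch of the same log-sum-exp merge. It is an illustration only: the name `merge_attn_states_ref` and the float32 upcast are assumptions, not part of this PR. It assumes the layouts documented above ([NUM_TOKENS, NUM_HEADS, HEAD_SIZE] outputs, [NUM_HEADS, NUM_TOKENS] LSEs) and mirrors the FA2 `+inf` handling.

```python
import torch


def merge_attn_states_ref(prefix_output, prefix_lse, suffix_output, suffix_lse):
    """Illustrative PyTorch reference for merge_attn_states_kernel (not part of this PR)."""
    # prefix_output / suffix_output: [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    # prefix_lse / suffix_lse:       [NUM_HEADS, NUM_TOKENS]
    p_lse = prefix_lse.transpose(0, 1).unsqueeze(-1).float()  # [NUM_TOKENS, NUM_HEADS, 1]
    s_lse = suffix_lse.transpose(0, 1).unsqueeze(-1).float()
    # Mirror the kernel: treat FA2's +inf (zero-length sequence) as -inf.
    p_lse = torch.where(p_lse == float("inf"), torch.full_like(p_lse, float("-inf")), p_lse)
    s_lse = torch.where(s_lse == float("inf"), torch.full_like(s_lse, float("-inf")), s_lse)
    max_lse = torch.maximum(p_lse, s_lse)
    p_w = torch.exp(p_lse - max_lse)
    s_w = torch.exp(s_lse - max_lse)
    denom = p_w + s_w
    # Compute the scales first, then multiply (see the numerical-stability NOTE above).
    out = prefix_output * (p_w / denom) + suffix_output * (s_w / denom)
    out_lse = (max_lse + torch.log(denom)).squeeze(-1).transpose(0, 1)  # [NUM_HEADS, NUM_TOKENS]
    return out, out_lse
```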


# def merge_attn_states(
# output: torch.Tensor,
# prefix_output: torch.Tensor,
# prefix_lse: torch.Tensor,
# suffix_output: torch.Tensor,
# suffix_lse: torch.Tensor,
# output_lse: Optional[torch.Tensor] = None,
# ) -> None:

# # NOTE(DefTruth): Currently, the custom merge_attn_states CUDA kernel
# # does not support FP8 dtype; fall back to the Triton kernel.
# def supported_dtypes(o: torch.Tensor) -> bool:
# return o.dtype in [torch.float32, torch.half, torch.bfloat16]

# # NOTE(DefTruth): Currently, the custom merge_attn_states CUDA
# # kernel loads/stores 128 bits (16 bytes) per memory access within a
# # thread. Hence, the headsize (headdim) must be a multiple of the
# # pack_size (float32 -> 4, half/bfloat16 -> 8).
# def supported_headdim(o: torch.Tensor) -> bool:
# headdim = o.shape[2] # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
# if o.dtype == torch.float32:
# return headdim % 4 == 0
# return headdim % 8 == 0

# if (current_platform.is_cuda() and supported_dtypes(output)
# and supported_headdim(output)):
# from vllm._custom_ops import merge_attn_states
# return merge_attn_states(output, prefix_output, prefix_lse,
# suffix_output, suffix_lse, output_lse)
# else:
# from vllm.attention.ops.triton_merge_attn_states import (
# merge_attn_states)
# return merge_attn_states(output, prefix_output, prefix_lse,
# suffix_output, suffix_lse, output_lse)


def cascade_attention(
output: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cu_query_lens: torch.Tensor,
max_query_len: int,
cu_prefix_query_lens: torch.Tensor,
prefix_kv_lens: torch.Tensor,
suffix_kv_lens: torch.Tensor,
max_kv_len: int,
softmax_scale: float,
alibi_slopes: Optional[torch.Tensor],
sliding_window: tuple[int, int],
logits_soft_cap: float,
block_table: torch.Tensor,
common_prefix_len: int,
fa_version: int,
prefix_scheduler_metadata: Optional[torch.Tensor] = None,
suffix_scheduler_metadata: Optional[torch.Tensor] = None,
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert alibi_slopes is None, "Cascade attention does not support ALiBi."
# TODO: Support sliding window.
assert sliding_window == (
-1,
-1,
), "Cascade attention does not support sliding window."

num_tokens = query.shape[0]
block_size = key_cache.shape[-3]
assert common_prefix_len % block_size == 0
num_common_kv_blocks = common_prefix_len // block_size
assert num_common_kv_blocks > 0
descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])

# Process shared prefix.
prefix_output, prefix_lse = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_prefix_query_lens,
seqused_k=prefix_kv_lens,
max_seqlen_q=num_tokens,
max_seqlen_k=common_prefix_len,
softmax_scale=softmax_scale,
causal=False,
window_size=sliding_window,
block_table=block_table[:1],
softcap=logits_soft_cap,
return_softmax_lse=True,
scheduler_metadata=prefix_scheduler_metadata,
fa_version=fa_version,
q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
)

descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])

# Process suffix per query.
suffix_output, suffix_lse = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
seqused_k=suffix_kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len - common_prefix_len,
softmax_scale=softmax_scale,
causal=True,
window_size=sliding_window,
block_table=block_table[:, num_common_kv_blocks:],
softcap=logits_soft_cap,
return_softmax_lse=True,
scheduler_metadata=suffix_scheduler_metadata,
fa_version=fa_version,
q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
)

# Merge prefix and suffix outputs, and store the result in output.
merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse)
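
For reference, this two-pass cascade is exact because softmax attention over a split key/value set decomposes via log-sum-exp (Section 2.2 of the paper cited above). In the notation of the comments here, with partial outputs $O_p, O_s$ and log-sum-exps $L_p, L_s$ for the shared prefix and the per-query suffix:

$$
m = \max(L_p, L_s), \qquad
O = \frac{e^{L_p - m}\, O_p + e^{L_s - m}\, O_s}{e^{L_p - m} + e^{L_s - m}}, \qquad
L = m + \log\left(e^{L_p - m} + e^{L_s - m}\right),
$$

which is what `merge_attn_states` computes per token and head, with the $\max$ subtraction providing numerical stability.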
83 changes: 42 additions & 41 deletions src/flag_gems/patches/patch_vllm_all.py
@@ -139,7 +139,7 @@ def custom_gems_flash_mla_forward(
return o


def custom_gems_flash_attention_impl_forwad(
def custom_gems_flash_attention_impl_forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
@@ -150,7 +150,11 @@ def custom_gems_flash_attention_impl_forwad(
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from flag_gems import flash_attn_varlen_func, reshape_and_cache_flash
from flag_gems import (
cascade_attention,
flash_attn_varlen_func,
reshape_and_cache_flash,
)

assert output is not None, "Output tensor must be provided."

@@ -203,16 +207,16 @@ def custom_gems_flash_attention_impl_forwad(
max_seqlen_q = local_metadata.local_max_query_len
max_seqlen_k = local_metadata.local_max_seq_len
block_table = local_metadata.local_block_table
# scheduler_metadata = local_metadata.local_scheduler_metadata
scheduler_metadata = local_metadata.local_scheduler_metadata
else:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
# scheduler_metadata = attn_metadata.scheduler_metadata
scheduler_metadata = attn_metadata.scheduler_metadata

# descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

flash_attn_varlen_func(
q=query[:num_actual_tokens],
@@ -229,44 +233,41 @@ def custom_gems_flash_attention_impl_forwad(
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
# scheduler_metadata=scheduler_metadata,
# fa_version=self.vllm_flash_attn_version,
# q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
scheduler_metadata=scheduler_metadata,
fa_version=2,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
return output

# TODO: Support cascade_attention.
raise NotImplementedError("Cascade attention is not implemented in flag_gems.")

# assert not use_local_attn, "Cascade attention does not support local attention."
# # Cascade attention (rare case).
# cascade_attention(
# output[:num_actual_tokens],
# query[:num_actual_tokens],
# key_cache,
# value_cache,
# cu_query_lens=attn_metadata.query_start_loc,
# max_query_len=attn_metadata.max_query_len,
# cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
# prefix_kv_lens=attn_metadata.prefix_kv_lens,
# suffix_kv_lens=attn_metadata.suffix_kv_lens,
# max_kv_len=attn_metadata.max_seq_len,
# softmax_scale=self.scale,
# alibi_slopes=self.alibi_slopes,
# sliding_window=self.sliding_window,
# logits_soft_cap=self.logits_soft_cap,
# block_table=attn_metadata.block_table,
# common_prefix_len=attn_metadata.common_prefix_len,
# fa_version=self.vllm_flash_attn_version,
# prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
# suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
# q_descale=layer._q_scale,
# k_descale=layer._k_scale,
# v_descale=layer._v_scale,
# )
# return output
assert not use_local_attn, "Cascade attention does not support local attention."
# Cascade attention (rare case).
cascade_attention(
output[:num_actual_tokens],
query[:num_actual_tokens],
key_cache,
value_cache,
cu_query_lens=attn_metadata.query_start_loc,
max_query_len=attn_metadata.max_query_len,
cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
prefix_kv_lens=attn_metadata.prefix_kv_lens,
suffix_kv_lens=attn_metadata.suffix_kv_lens,
max_kv_len=attn_metadata.max_seq_len,
softmax_scale=self.scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window,
logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len,
fa_version=2,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
q_descale=layer._q_scale,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
)
return output


def custom_silu_and_mul(out: torch.Tensor, input: torch.Tensor):
@@ -380,7 +381,7 @@ def apply_gems_patches_to_vllm(verbose=True):
TritonMLAImpl, "_forward_decode", custom_gems_flash_mla_forward, verbose
)
patch_module_method(
FlashAttentionImpl, "forward", custom_gems_flash_attention_impl_forwad, verbose
FlashAttentionImpl, "forward", custom_gems_flash_attention_impl_forward, verbose
)
patch_vllm_lib("_C", "silu_and_mul", custom_silu_and_mul, "CUDA", verbose)
patch_vllm_lib(
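
For context, a hedged usage sketch of how these patches are typically enabled. The import path follows the file location shown above and may differ in practice; it is an assumption, not part of the diff.

```python
# Sketch only: apply the flag_gems patches before building the vLLM engine so that
# FlashAttentionImpl.forward is routed to custom_gems_flash_attention_impl_forward.
from flag_gems.patches.patch_vllm_all import apply_gems_patches_to_vllm

apply_gems_patches_to_vllm(verbose=True)
```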