84 changes: 59 additions & 25 deletions vllm/_ipex_ops.py
@@ -270,21 +270,23 @@ def reshape_and_cache_flash(

@staticmethod
def flash_attn_varlen_func(
out: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
seqused_k: torch.Tensor, # we don't support this in ipex kernel
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: float,
causal: bool,
block_table: torch.Tensor,
alibi_slopes: torch.Tensor | None,
softmax_scale: float | None = None,
causal: bool = False,
out: torch.Tensor | None = None,
block_table: torch.Tensor | None = None,
alibi_slopes: torch.Tensor | None = None,
window_size: list[int] | None = None,
softcap: float | None = 0.0,
seqused_k: torch.Tensor | None = None,
cu_seqlens_k: torch.Tensor | None = None,
# passed in qwen vl
dropout_p: float = 0.0,
# The following parameters are not used in ipex kernel currently,
# we keep API compatible to CUDA's.
scheduler_metadata=None,
@@ -295,31 +297,63 @@ def flash_attn_varlen_func(
num_splits=0,
s_aux: torch.Tensor | None = None,
):
if out is None:
out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
real_window_size: tuple[int, int]
if window_size is None:
real_window_size = (-1, -1)
else:
assert len(window_size) == 2
real_window_size = (window_size[0], window_size[1])
return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
out,
q.contiguous(),
k,
v,
cu_seqlens_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
softmax_scale,
causal,
block_table,
alibi_slopes,
softcap=softcap,
window_size_left=real_window_size[0],
window_size_right=real_window_size[1],
k_scale=1.0,
v_scale=1.0,
)

if block_table is None:
assert cu_seqlens_k is not None, (
"cu_seqlens_k can't be None when calling varlen_attention."
)
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
ipex_ops.varlen_attention(
q.contiguous(),
k.contiguous(),
v.contiguous(),
out,
cu_seqlens_q,
cu_seqlens_k,
None,
max_seqlen_q,
max_seqlen_k,
0.0,
softmax_scale,
False,
causal,
False,
None,
real_window_size[0],
real_window_size[1],
-1,
)
return out
else:
return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
out,
q.contiguous(),
k,
v,
cu_seqlens_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
softmax_scale,
causal,
block_table,
alibi_slopes,
sink=s_aux,
softcap=softcap,
window_size_left=real_window_size[0],
window_size_right=real_window_size[1],
k_scale=1.0,
v_scale=1.0,
)

@staticmethod
def get_scheduler_metadata(
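To make the new control flow in `_ipex_ops.py` easier to follow, here is a minimal, self-contained sketch of the dispatch rule this hunk introduces. The kernel calls are stubbed out: `varlen_attention_stub` and `paged_flash_attn_stub` are hypothetical placeholders for `ipex_ops.varlen_attention` and `ipex.llm.modules.PagedAttention.flash_attn_varlen_func`; only the branching, default handling, and window translation mirror the diff.

```python
import torch


def varlen_attention_stub(q, k, v, out, cu_q, cu_k, max_q, max_k, scale, causal, win):
    # Placeholder for ipex_ops.varlen_attention (packed varlen path, no KV cache).
    return out


def paged_flash_attn_stub(out, q, k, v, cu_q, seqused_k, max_q, max_k, scale,
                          causal, block_table, win):
    # Placeholder for ipex.llm.modules.PagedAttention.flash_attn_varlen_func.
    return out


def flash_attn_varlen_dispatch(
    q, k, v, cu_seqlens_q, max_seqlen_q, max_seqlen_k,
    softmax_scale=None, causal=False, out=None, block_table=None,
    window_size=None, seqused_k=None, cu_seqlens_k=None,
):
    # Allocate the output lazily, as the diff now does.
    if out is None:
        out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    # Translate the optional [left, right] window into the (-1, -1) "disabled" form.
    win = (-1, -1) if window_size is None else (window_size[0], window_size[1])
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5
    if block_table is None:
        # No KV-cache block table: the packed varlen case (e.g. ViT / Qwen-VL
        # vision towers), which requires cu_seqlens_k instead of seqused_k.
        assert cu_seqlens_k is not None, "cu_seqlens_k is required without a block_table"
        return varlen_attention_stub(q, k, v, out, cu_seqlens_q, cu_seqlens_k,
                                     max_seqlen_q, max_seqlen_k, softmax_scale,
                                     causal, win)
    # With a block table we are in the paged-KV prefill/decode case.
    return paged_flash_attn_stub(out, q, k, v, cu_seqlens_q, seqused_k,
                                 max_seqlen_q, max_seqlen_k, softmax_scale,
                                 causal, block_table, win)


# Toy usage: two packed sequences of 3 tokens each, no block table.
q = k = v = torch.randn(6, 8, 64)
cu = torch.tensor([0, 3, 6], dtype=torch.int32)
out = flash_attn_varlen_dispatch(q, k, v, cu, 3, 3, cu_seqlens_k=cu)
```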
35 changes: 18 additions & 17 deletions vllm/attention/layer.py
@@ -123,6 +123,11 @@ def maybe_get_vit_flash_attn_backend(
):
attn_backend = _Backend.FLASH_ATTN
use_upstream_fa = True
elif current_platform.is_xpu():
assert attn_backend == _Backend.FLASH_ATTN, (
"XPU platform only supports FLASH_ATTN as vision attention backend."
)
use_upstream_fa = False
else:
return _Backend.TORCH_SDPA, None

Expand All @@ -133,7 +138,7 @@ def maybe_get_vit_flash_attn_backend(
if use_upstream_fa:
from flash_attn import flash_attn_varlen_func
[Review comment, Collaborator] Suggested change:
-                from flash_attn import flash_attn_varlen_func
+                from vllm.attention.utils.fa_utils import flash_attn_varlen_func

[Reply, Contributor Author] Understood your point, but the change should go in the use_upstream_fa=False branch. Updated.

else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
else:
flash_attn_varlen_func = None

@@ -521,22 +526,18 @@ def __init__(
# If vllm native fa is selected, we use it directly.
use_upstream_fa = False

if current_platform.is_xpu():
# currently, only torch_sdpa is supported on xpu
self.attn_backend = _Backend.TORCH_SDPA
else:
self.attn_backend = (
backend
if backend
in {
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
_Backend.PALLAS,
_Backend.ROCM_AITER_FA,
_Backend.FLASH_ATTN,
}
else _Backend.TORCH_SDPA
)
self.attn_backend = (
backend
if backend
in {
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
_Backend.PALLAS,
_Backend.ROCM_AITER_FA,
_Backend.FLASH_ATTN,
}
else _Backend.TORCH_SDPA
)

self.attn_backend, self._flash_attn_varlen_func = (
maybe_get_vit_flash_attn_backend(
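A condensed sketch of how the two pieces in `layer.py` now compose, with the backend enum and helper reduced to stand-ins (`Backend`, `resolve_vit_attn_backend` are illustrative names, not the real vLLM symbols). It only shows the order of decisions: the per-platform TORCH_SDPA override for XPU is gone, the requested backend is filtered against the supported set, and the helper's new XPU branch pins the result to FLASH_ATTN with the in-tree (non-upstream) flash-attention wrapper. The non-XPU return value is simplified; the real helper has more branches.

```python
from enum import Enum, auto


class Backend(Enum):
    # Reduced stand-in for vllm.attention.backends.registry._Backend.
    TORCH_SDPA = auto()
    XFORMERS = auto()
    PALLAS = auto()
    ROCM_AITER_FA = auto()
    FLASH_ATTN = auto()


_SUPPORTED = {Backend.TORCH_SDPA, Backend.XFORMERS, Backend.PALLAS,
              Backend.ROCM_AITER_FA, Backend.FLASH_ATTN}


def resolve_vit_attn_backend(requested: Backend, is_xpu: bool) -> tuple[Backend, bool]:
    """Return (backend, use_upstream_fa), mirroring the two steps in layer.py."""
    # Step 1: no XPU-specific override anymore; every platform goes through
    # the same whitelist filter and falls back to TORCH_SDPA.
    backend = requested if requested in _SUPPORTED else Backend.TORCH_SDPA
    # Step 2: the helper's new XPU branch insists on FLASH_ATTN and selects
    # the in-tree wrapper (use_upstream_fa=False), which routes to IPEX kernels.
    if is_xpu:
        assert backend == Backend.FLASH_ATTN, (
            "XPU platform only supports FLASH_ATTN as vision attention backend.")
        return backend, False
    return backend, True  # simplified; the real helper branches further here


print(resolve_vit_attn_backend(Backend.FLASH_ATTN, is_xpu=True))
# -> (Backend.FLASH_ATTN, False)
```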
2 changes: 1 addition & 1 deletion vllm/attention/ops/vit_attn_wrappers.py
@@ -70,7 +70,7 @@ def flash_attn_maxseqlen_wrapper(
if use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func(
q,
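For context on the wrapper that this import change feeds, below is a small shape-only illustration of the packing it performs before calling `flash_attn_varlen_func`. The tensor sizes are made up, and plain `reshape` stands in for the `einops.rearrange("b s ... -> (b s) ...")` used in the source.

```python
import torch

# Illustrative shapes: (batch, seq, heads, head_dim) packed into
# (batch * seq, heads, head_dim), which is what the varlen kernel expects,
# paired with cu_seqlens marking each sequence boundary.
b, s, h, d = 2, 5, 4, 32
q = torch.randn(b, s, h, d)

# Equivalent of einops.rearrange(q, "b s ... -> (b s) ...")
q_packed = q.reshape(b * s, h, d)

# Cumulative sequence lengths for two equal-length sequences: [0, s, 2*s].
cu_seqlens = torch.arange(0, (b + 1) * s, s, dtype=torch.int32)
print(q_packed.shape, cu_seqlens.tolist())  # torch.Size([10, 4, 32]) [0, 5, 10]
```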
7 changes: 3 additions & 4 deletions vllm/model_executor/models/qwen2_5_vl.py
@@ -364,6 +364,8 @@ def __init__(

if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
self.use_upstream_fa = True
if current_platform.is_xpu():
self.use_upstream_fa = False
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN,
_Backend.ROCM_AITER_FA,
@@ -856,10 +858,7 @@ def compute_attn_mask_seqlen(
) -> tuple[torch.Tensor, torch.Tensor]:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device)
if (
self.attn_backend == _Backend.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA
):
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == _Backend.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
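For readers unfamiliar with the `cu_seqlens` convention, a small worked example of what `compute_attn_mask_seqlen` produces; the values are made up. The `qwen2_vl.py` hunk below follows the same pattern and differs only in returning Python values via `.item()` / `.tolist()` instead of tensors.

```python
import torch

# Cumulative token counts per vision chunk, e.g. chunks of 3, 4 and 2 patches.
cu_seqlens = torch.tensor([0, 3, 7, 9])

# Per-chunk lengths and their maximum.
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]   # tensor([3, 4, 2])
max_seqlen = seqlens.max()                   # tensor(4)

# FLASH_ATTN / ROCM_AITER_FA consume max_seqlen; XFORMERS keeps the full
# seqlens vector instead.
print(seqlens.tolist(), int(max_seqlen))     # [3, 4, 2] 4
```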
5 changes: 1 addition & 4 deletions vllm/model_executor/models/qwen2_vl.py
@@ -789,10 +789,7 @@ def compute_attn_mask_seqlen(
self, cu_seqlens: torch.Tensor
) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None
if (
self.attn_backend == _Backend.FLASH_ATTN
or self.attn_backend == _Backend.ROCM_AITER_FA
):
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
6 changes: 6 additions & 0 deletions vllm/platforms/xpu.py
@@ -115,6 +115,12 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
device_props = torch.xpu.get_device_properties(device_id)
return device_props.total_memory

@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
from vllm.attention.backends.registry import _Backend

return _Backend.FLASH_ATTN

@classmethod
def inference_mode(cls):
return torch.no_grad()
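A brief usage sketch of the new platform hook. The `head_size` and `dtype` values are arbitrary, and while the hook's signature comes from the diff, treating `vllm.platforms.current_platform` as the entry point is an assumption about how callers reach it.

```python
import torch
from vllm.platforms import current_platform

# On an XPU build, the new hook pins the vision-tower attention backend to
# FLASH_ATTN, which maybe_get_vit_flash_attn_backend then resolves to the
# IPEX-backed varlen kernel (use_upstream_fa=False).
backend = current_platform.get_vit_attn_backend(head_size=80, dtype=torch.float16)
print(backend)  # expected: _Backend.FLASH_ATTN on XPU
```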