diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py
index 60ee0124c3d9..95c17cb331f6 100644
--- a/vllm/_ipex_ops.py
+++ b/vllm/_ipex_ops.py
@@ -270,21 +270,23 @@ def reshape_and_cache_flash(

     @staticmethod
     def flash_attn_varlen_func(
-        out: torch.Tensor,
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
         cu_seqlens_q: torch.Tensor,
-        seqused_k: torch.Tensor,  # we don't support this in ipex kernel
         max_seqlen_q: int,
         max_seqlen_k: int,
-        softmax_scale: float,
-        causal: bool,
-        block_table: torch.Tensor,
-        alibi_slopes: torch.Tensor | None,
+        softmax_scale: float | None = None,
+        causal: bool = False,
+        out: torch.Tensor | None = None,
+        block_table: torch.Tensor | None = None,
+        alibi_slopes: torch.Tensor | None = None,
         window_size: list[int] | None = None,
         softcap: float | None = 0.0,
+        seqused_k: torch.Tensor | None = None,
         cu_seqlens_k: torch.Tensor | None = None,
+        # passed in qwen vl
+        dropout_p: float = 0.0,
         # The following parameters are not used in ipex kernel currently,
         # we keep API compatible to CUDA's.
         scheduler_metadata=None,
@@ -295,31 +297,63 @@ def flash_attn_varlen_func(
         num_splits=0,
         s_aux: torch.Tensor | None = None,
     ):
+        if out is None:
+            out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
         real_window_size: tuple[int, int]
         if window_size is None:
             real_window_size = (-1, -1)
         else:
             assert len(window_size) == 2
             real_window_size = (window_size[0], window_size[1])
-        return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
-            out,
-            q.contiguous(),
-            k,
-            v,
-            cu_seqlens_q,
-            seqused_k,
-            max_seqlen_q,
-            max_seqlen_k,
-            softmax_scale,
-            causal,
-            block_table,
-            alibi_slopes,
-            softcap=softcap,
-            window_size_left=real_window_size[0],
-            window_size_right=real_window_size[1],
-            k_scale=1.0,
-            v_scale=1.0,
-        )
+
+        if block_table is None:
+            assert cu_seqlens_k is not None, (
+                "cu_seqlens_k can't be None when calling varlen_attention."
+            )
+            if softmax_scale is None:
+                softmax_scale = q.shape[-1] ** (-0.5)
+            ipex_ops.varlen_attention(
+                q.contiguous(),
+                k.contiguous(),
+                v.contiguous(),
+                out,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                None,
+                max_seqlen_q,
+                max_seqlen_k,
+                0.0,
+                softmax_scale,
+                False,
+                causal,
+                False,
+                None,
+                real_window_size[0],
+                real_window_size[1],
+                -1,
+            )
+            return out
+        else:
+            return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+                out,
+                q.contiguous(),
+                k,
+                v,
+                cu_seqlens_q,
+                seqused_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                softmax_scale,
+                causal,
+                block_table,
+                alibi_slopes,
+                sink=s_aux,
+                softcap=softcap,
+                window_size_left=real_window_size[0],
+                window_size_right=real_window_size[1],
+                k_scale=1.0,
+                v_scale=1.0,
+            )

     @staticmethod
     def get_scheduler_metadata(
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 22eaa22b8b38..17e025155a43 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -123,6 +123,11 @@ def maybe_get_vit_flash_attn_backend(
     ):
         attn_backend = _Backend.FLASH_ATTN
         use_upstream_fa = True
+    elif current_platform.is_xpu():
+        assert attn_backend == _Backend.FLASH_ATTN, (
+            "XPU platform only supports FLASH_ATTN as vision attention backend."
+        )
+        use_upstream_fa = False
     else:
         return _Backend.TORCH_SDPA, None

@@ -133,7 +138,7 @@ def maybe_get_vit_flash_attn_backend(
             if use_upstream_fa:
                 from flash_attn import flash_attn_varlen_func
             else:
-                from vllm.vllm_flash_attn import flash_attn_varlen_func
+                from vllm.attention.utils.fa_utils import flash_attn_varlen_func
     else:
         flash_attn_varlen_func = None

@@ -521,22 +526,18 @@ def __init__(

         # If vllm native fa is selected, we use it directly.
         use_upstream_fa = False
-        if current_platform.is_xpu():
-            # currently, only torch_sdpa is supported on xpu
-            self.attn_backend = _Backend.TORCH_SDPA
-        else:
-            self.attn_backend = (
-                backend
-                if backend
-                in {
-                    _Backend.TORCH_SDPA,
-                    _Backend.XFORMERS,
-                    _Backend.PALLAS,
-                    _Backend.ROCM_AITER_FA,
-                    _Backend.FLASH_ATTN,
-                }
-                else _Backend.TORCH_SDPA
-            )
+        self.attn_backend = (
+            backend
+            if backend
+            in {
+                _Backend.TORCH_SDPA,
+                _Backend.XFORMERS,
+                _Backend.PALLAS,
+                _Backend.ROCM_AITER_FA,
+                _Backend.FLASH_ATTN,
+            }
+            else _Backend.TORCH_SDPA
+        )

         self.attn_backend, self._flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py
index f71f49a1a31b..6cefe7441668 100644
--- a/vllm/attention/ops/vit_attn_wrappers.py
+++ b/vllm/attention/ops/vit_attn_wrappers.py
@@ -70,7 +70,7 @@ def flash_attn_maxseqlen_wrapper(
     if use_upstream_fa:
         from flash_attn import flash_attn_varlen_func
     else:
-        from vllm.vllm_flash_attn import flash_attn_varlen_func
+        from vllm.attention.utils.fa_utils import flash_attn_varlen_func
     q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
     output = flash_attn_varlen_func(
         q,
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 3d67653726bd..3585783e4ccc 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -364,6 +364,8 @@ def __init__(

         if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
             self.use_upstream_fa = True
+        if current_platform.is_xpu():
+            self.use_upstream_fa = False
         self.is_flash_attn_backend = self.attn_backend in {
             _Backend.FLASH_ATTN,
             _Backend.ROCM_AITER_FA,
@@ -856,10 +858,7 @@ def compute_attn_mask_seqlen(
     ) -> tuple[torch.Tensor, torch.Tensor]:
         max_seqlen = torch.zeros([], device=cu_seqlens.device)
         seqlens = torch.zeros(1, device=cu_seqlens.device)
-        if (
-            self.attn_backend == _Backend.FLASH_ATTN
-            or self.attn_backend == _Backend.ROCM_AITER_FA
-        ):
+        if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index f0d7e2e7d7ec..a81acf9f9a36 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -789,10 +789,7 @@ def compute_attn_mask_seqlen(
         self, cu_seqlens: torch.Tensor
     ) -> tuple[int | None, list[int] | None]:
         max_seqlen, seqlens = None, None
-        if (
-            self.attn_backend == _Backend.FLASH_ATTN
-            or self.attn_backend == _Backend.ROCM_AITER_FA
-        ):
+        if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index cd65cba6b492..07ab759e4baa 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -115,6 +115,12 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
         device_props = torch.xpu.get_device_properties(device_id)
         return device_props.total_memory

+    @classmethod
+    def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
+        from vllm.attention.backends.registry import _Backend
+
+        return _Backend.FLASH_ATTN
+
     @classmethod
     def inference_mode(cls):
         return torch.no_grad()
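
For reviewers, a minimal usage sketch (not part of the diff) of how the reworked `ipex_ops.flash_attn_varlen_func` is expected to dispatch after this change: with `block_table=None` (the ViT / Qwen-VL path) it routes to `ipex_ops.varlen_attention`, otherwise it falls through to `ipex.llm.modules.PagedAttention.flash_attn_varlen_func`. The tensor shapes, the `"xpu"` device string, and the availability of intel_extension_for_pytorch are assumptions for illustration only.

```python
# Illustrative sketch only: exercises the block_table=None branch added in
# this PR on an XPU device. Shapes and dtypes are arbitrary assumptions.
import torch

from vllm._ipex_ops import ipex_ops

num_tokens, num_heads, head_dim = 32, 8, 64
q = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.float16, device="xpu")
k = torch.randn_like(q)
v = torch.randn_like(q)
# A single sequence of 32 tokens; cu_seqlens_k is required when block_table is None.
cu_seqlens = torch.tensor([0, num_tokens], dtype=torch.int32, device="xpu")

out = ipex_ops.flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q=cu_seqlens,
    max_seqlen_q=num_tokens,
    max_seqlen_k=num_tokens,
    causal=False,          # ViT attention is bidirectional
    cu_seqlens_k=cu_seqlens,
)
print(out.shape)  # same shape as q; softmax_scale defaults to head_dim ** -0.5
```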