[Multimodal][XPU] Enable vision attn backend for xpu platform #27525
Conversation
Code Review
This pull request successfully adds support for vision attention on XPU platforms by integrating the IPEX backend. The changes are logical and well-contained. However, there is a significant amount of duplicated code for the IPEX attention implementation in qwen2_vl.py and qwen2_5_vl.py. I've added comments suggesting a refactoring to improve code maintainability. Addressing this will make the codebase cleaner and easier to manage in the future.
```python
elif self.attn_backend == _Backend.IPEX:
    from vllm._ipex_ops import ipex_ops

    q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

    output = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    ipex_ops.varlen_attention(
        q,
        k,
        v,
        output,
        cu_seqlens,
        cu_seqlens,
        None,
        max_seqlen,
        max_seqlen,
        pdropout=0.0,
        softmax_scale=1.0 / (q.shape[-1] ** 0.5),
        zero_tensors=False,
        is_causal=False,
        return_softmax=False,
        gen_=None,
        window_size_left=-1,
        window_size_right=-1,
        logits_soft_cap=-1,
    )
    context_layer = rearrange(
        output, "(b s) h d -> s b (h d)", b=batch_size
    ).contiguous()
```
This block of code for IPEX attention is identical to the one in vllm/model_executor/models/qwen2_vl.py. To improve maintainability and avoid code duplication, this logic should be refactored into a shared function.
For example, you could create a helper function in a common utility file (e.g., `vllm/model_executor/models/vision.py`):

```python
from einops import rearrange
import torch

from vllm._ipex_ops import ipex_ops


def ipex_varlen_attention(q, k, v, cu_seqlens, max_seqlen, batch_size):
    q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
    output = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    ipex_ops.varlen_attention(
        q,
        k,
        v,
        output,
        cu_seqlens,
        cu_seqlens,
        None,
        max_seqlen,
        max_seqlen,
        pdropout=0.0,
        softmax_scale=1.0 / (q.shape[-1] ** 0.5),
        zero_tensors=False,
        is_causal=False,
        return_softmax=False,
        gen_=None,
        window_size_left=-1,
        window_size_right=-1,
        logits_soft_cap=-1,
    )
    context_layer = rearrange(
        output, "(b s) h d -> s b (h d)", b=batch_size
    ).contiguous()
    return context_layer
```

Then, you can call this function from both `Qwen2_5_VisionAttention` and `Qwen2VisionAttention`.
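A hedged sketch of what the call site could then look like in both attention classes (the `elif` and the surrounding variable names follow the hunk quoted above; this is not the actual diff):

```python
    elif self.attn_backend == _Backend.IPEX:
        # Delegate to the shared helper instead of duplicating the IPEX call.
        context_layer = ipex_varlen_attention(
            q, k, v, cu_seqlens, max_seqlen, batch_size
        )
```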
💡 Codex Review
Here are some automated review suggestions for this pull request.
vllm/platforms/xpu.py (outdated)

```python
@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
    from vllm.attention.backends.registry import _Backend

    return _Backend.IPEX
```
Returning IPEX backend breaks non‑Qwen ViT models on XPU
The new `get_vit_attn_backend` now always returns `_Backend.IPEX`. A number of existing vision models (e.g. GLM-4V, Dots OCR, SigLIP2NaViT, Keye) guard their initialization with `if self.attn_backend not in {FLASH_ATTN, TORCH_SDPA, XFORMERS, ROCM_AITER_FA}: raise RuntimeError(...)` and were written assuming XPU would report `TORCH_SDPA`. With this change those models now receive `_Backend.IPEX` and immediately raise an unsupported-backend error before any inference can run. Either keep returning `TORCH_SDPA` here or update every model's whitelist to include `_Backend.IPEX` to avoid hard-failing on XPU.
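As a rough illustration of the second option, each affected model's guard could be widened to accept the new backend (a sketch only; the exact set of allowed backends varies per model):

```python
from vllm.attention.backends.registry import _Backend

# Sketch: accept the IPEX backend that XPU now reports instead of raising.
if self.attn_backend not in {
    _Backend.FLASH_ATTN,
    _Backend.TORCH_SDPA,
    _Backend.XFORMERS,
    _Backend.ROCM_AITER_FA,
    _Backend.IPEX,
}:
    raise RuntimeError(
        f"Unsupported vision attention backend: {self.attn_backend}"
    )
```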
vllm/attention/layer.py (outdated)

```python
elif (
    self.attn_backend == _Backend.TORCH_SDPA
    or self.attn_backend == _Backend.IPEX
):
```
Suggested change:

```diff
-elif (
-    self.attn_backend == _Backend.TORCH_SDPA
-    or self.attn_backend == _Backend.IPEX
-):
+elif self.attn_backend in (_Backend.TORCH_SDPA, _Backend.IPEX):
```
Can you also handle IPEX backend in MultiHeadAttention?
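For reference, a minimal sketch of how `MultiHeadAttention` could route to the IPEX kernel, assuming equal-length sequences of shape `(batch, seq, heads, head_dim)` and that query, key, and value share the same number of heads; the names `bsz`, `q_len`, `self.scale`, `self.num_heads`, and `self.head_size` are placeholders, not the actual vLLM code:

```python
    elif self.attn_backend == _Backend.IPEX:
        from vllm._ipex_ops import ipex_ops

        # All sequences share the same length, so the cumulative sequence
        # lengths are just multiples of q_len.
        cu_seqlens = torch.arange(
            0, (bsz + 1) * q_len, step=q_len,
            dtype=torch.int32, device=query.device,
        )
        q, k, v = (
            x.reshape(-1, self.num_heads, self.head_size)
            for x in (query, key, value)
        )
        out = torch.empty_like(q)
        ipex_ops.varlen_attention(
            q, k, v, out,
            cu_seqlens, cu_seqlens, None,
            q_len, q_len,
            pdropout=0.0,
            softmax_scale=self.scale,
            zero_tensors=False,
            is_causal=False,
            return_softmax=False,
            gen_=None,
            window_size_left=-1,
            window_size_right=-1,
            logits_soft_cap=-1,
        )
        out = out.reshape(bsz, q_len, -1)
```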
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from 04011b1 to 6d3212b.
This pull request has merge conflicts that must be resolved before it can be merged.
@DarkLight1337 I made some changes; can you help review again? Thanks.
You can ignore the other pre-commit errors
vllm/attention/layer.py (outdated)

```python
if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
    if attn_backend == _Backend.ROCM_AITER_FA:
        from aiter import flash_attn_varlen_func
    elif current_platform.is_xpu():
```
Maybe we can avoid this elif branch and use `fa_utils.flash_attn_varlen_func` in the else branch.
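In other words, something like the following sketch (based on the hunk above; the surrounding code is not shown in this thread):

```python
    if attn_backend == _Backend.ROCM_AITER_FA:
        from aiter import flash_attn_varlen_func
    else:
        # On XPU, fa_utils already resolves flash_attn_varlen_func to the
        # IPEX-backed implementation, so no dedicated is_xpu() branch is needed.
        from vllm.attention.utils.fa_utils import flash_attn_varlen_func
```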
```python
) -> torch.Tensor:
    if is_rocm_aiter:
        from aiter import flash_attn_varlen_func
    elif current_platform.is_xpu():
```
same as above
```python
    flash_attn_varlen_func = ops.flash_attn_varlen_func
else:
    if use_upstream_fa:
        from flash_attn import flash_attn_varlen_func
```
Suggested change:

```diff
-from flash_attn import flash_attn_varlen_func
+from vllm.attention.utils.fa_utils import flash_attn_varlen_func
```
I understand your point, but the change should go into the `use_upstream_fa=False` branch. Updated.
vllm/_ipex_ops.py (outdated)

```python
    block_table: torch.Tensor | None = None,
    alibi_slopes: torch.Tensor | None = None,
    window_size: torch.Tensor | None = None,
    softcap: torch.Tensor | None = 0.0,
```
Please double-check the type.
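If the `0.0` default is intended, the annotation presumably wants a plain float rather than an optional tensor (an assumption based only on this hunk):

```python
    softcap: float = 0.0,
```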
Updated.
Signed-off-by: Yejing Lai <[email protected]> Signed-off-by: Yan Ma <[email protected]>
Signed-off-by: Guancheng Fu <[email protected]> Signed-off-by: Yan Ma <[email protected]>
Signed-off-by: Yan Ma <[email protected]>
…ort prefill only scenario Signed-off-by: Kunshang Ji <[email protected]>
Signed-off-by: Yan Ma <[email protected]>
Signed-off-by: Yan Ma <[email protected]>
Head branch was pushed to by a user without write access
…oject#27525) Signed-off-by: Yan Ma <[email protected]> Signed-off-by: Kunshang Ji <[email protected]> Co-authored-by: Yejing Lai <[email protected]> Co-authored-by: Guancheng Fu <[email protected]> Co-authored-by: Kunshang Ji <[email protected]>
Purpose

This PR uses `FLASH_ATTN` as the vision attention backend for the XPU platform and actually calls the `varlen_attention` kernel in IPEX by dispatching inside `flash_attn_varlen_func`.
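Roughly, the XPU dispatch could look like the following sketch; the wrapper name `xpu_flash_attn_varlen_func` and its signature are illustrative assumptions pieced together from the hunks in the review above, not the exact merged code:

```python
import torch

from vllm._ipex_ops import ipex_ops


def xpu_flash_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: float | None = None,
    causal: bool = False,
) -> torch.Tensor:
    """Forward a flash-attention-style varlen call to IPEX's kernel on XPU."""
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5
    out = torch.empty_like(q)
    ipex_ops.varlen_attention(
        q, k, v, out,
        cu_seqlens_q, cu_seqlens_k, None,
        max_seqlen_q, max_seqlen_k,
        pdropout=0.0,
        softmax_scale=softmax_scale,
        zero_tensors=False,
        is_causal=causal,
        return_softmax=False,
        gen_=None,
        window_size_left=-1,
        window_size_right=-1,
        logits_soft_cap=-1,
    )
    return out
```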
Test Plan

`python examples/offline_inference/vision_language.py -m glm-4v` and `python examples/offline_inference/vision_language.py -m qwen2_5_vl`

Test Result

`python examples/offline_inference/vision_language.py -m qwen2_5_vl`: