
Commit f17de7d (parent 42ba6fb)

use FLASH_ATTN as xpu vision backend

Signed-off-by: Yan Ma <[email protected]>

6 files changed: 67 additions, 77 deletions

vllm/_ipex_ops.py (38 additions, 16 deletions)

@@ -270,21 +270,23 @@ def reshape_and_cache_flash(
 
     @staticmethod
     def flash_attn_varlen_func(
-        out: torch.Tensor,
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
         cu_seqlens_q: torch.Tensor,
         max_seqlen_q: int,
         max_seqlen_k: int,
-        softmax_scale: float,
-        causal: bool,
-        block_table: Optional[torch.Tensor] = None,
-        alibi_slopes: Optional[torch.Tensor] = None,
-        window_size: Optional[list[int]] = None,
-        softcap: Optional[float] = 0.0,
-        seqused_k: Optional[torch.Tensor] = None,
-        cu_seqlens_k: Optional[torch.Tensor] = None,
+        softmax_scale: float | None = None,
+        causal: bool = False,
+        out: torch.Tensor | None = None,
+        block_table: torch.Tensor | None = None,
+        alibi_slopes: torch.Tensor | None = None,
+        window_size: torch.Tensor | None = None,
+        softcap: torch.Tensor | None = 0.0,
+        seqused_k: torch.Tensor | None = None,
+        cu_seqlens_k: torch.Tensor | None = None,
+        # passed in qwen vl
+        dropout_p: float = 0.0,
         # The following parameters are not used in ipex kernel currently,
         # we keep API compatible to CUDA's.
         scheduler_metadata=None,
@@ -295,6 +297,8 @@ def flash_attn_varlen_func(
         num_splits=0,
         s_aux: torch.Tensor | None = None,
     ):
+        if out is None:
+            out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
         real_window_size: tuple[int, int]
         if window_size is None:
             real_window_size = (-1, -1)
@@ -303,13 +307,31 @@ def flash_attn_varlen_func(
             real_window_size = (window_size[0], window_size[1])
 
         if block_table is None:
-            # need check use cu_seqlens_k or seqused_k!!!
-            ipex_ops.varlen_attention(q.contiguous(), k.contiguous(),
-                                      v.contiguous(), out, cu_seqlens_q,
-                                      cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                                      0.0, softmax_scale, False, causal, False,
-                                      None, real_window_size[0],
-                                      real_window_size[1], softcap)
+            assert cu_seqlens_k is not None, (
+                "cu_seqlens_k can't be None when calling varlen_attention."
+            )
+            if softmax_scale is None:
+                softmax_scale = q.shape[-1] ** (-0.5)
+            ipex_ops.varlen_attention(
+                q.contiguous(),
+                k.contiguous(),
+                v.contiguous(),
+                out,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                None,
+                max_seqlen_q,
+                max_seqlen_k,
+                0.0,
+                softmax_scale,
+                False,
+                causal,
+                False,
+                None,
+                real_window_size[0],
+                real_window_size[1],
+                -1,
+            )
             return out
         else:
             return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
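
With these defaults, the wrapper no longer needs a pre-allocated output tensor or an explicit softmax scale: out is created internally when omitted and softmax_scale falls back to head_dim ** -0.5. A minimal usage sketch (not part of the commit; assumes an XPU build of vLLM with IPEX available, and uses made-up tensor shapes):

    # Sketch only: exercise the reworked varlen signature on hypothetical shapes.
    import torch
    from vllm._ipex_ops import ipex_ops

    seqlens = [16, 32]                      # two variable-length sequences
    total, heads, head_dim = sum(seqlens), 8, 64
    q = torch.randn(total, heads, head_dim, dtype=torch.float16, device="xpu")
    k, v = torch.randn_like(q), torch.randn_like(q)
    cu_seqlens = torch.tensor([0, 16, 48], dtype=torch.int32, device="xpu")

    # out is allocated inside the wrapper; softmax_scale defaults to head_dim ** -0.5
    out = ipex_ops.flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens,
        max_seqlen_q=max(seqlens),
        max_seqlen_k=max(seqlens),
        cu_seqlens_k=cu_seqlens,
        causal=False,
    )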

vllm/attention/layer.py (21 additions, 19 deletions)

@@ -123,12 +123,20 @@ def maybe_get_vit_flash_attn_backend(
     ):
         attn_backend = _Backend.FLASH_ATTN
         use_upstream_fa = True
+    elif current_platform.is_xpu():
+        assert attn_backend == _Backend.FLASH_ATTN, (
+            "XPU platform only supports FLASH_ATTN as vision attention backend."
+        )
     else:
         return _Backend.TORCH_SDPA, None
 
     if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
         if attn_backend == _Backend.ROCM_AITER_FA:
             from aiter import flash_attn_varlen_func
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops as ops
+
+            flash_attn_varlen_func = ops.flash_attn_varlen_func
         else:
             if use_upstream_fa:
                 from flash_attn import flash_attn_varlen_func
@@ -521,21 +529,18 @@ def __init__(
         # If vllm native fa is selected, we use it directly.
         use_upstream_fa = False
 
-        if current_platform.is_xpu():
-            self.attn_backend = _Backend.IPEX
-        else:
-            self.attn_backend = (
-                backend
-                if backend
-                in {
-                    _Backend.TORCH_SDPA,
-                    _Backend.XFORMERS,
-                    _Backend.PALLAS,
-                    _Backend.ROCM_AITER_FA,
-                    _Backend.FLASH_ATTN,
-                }
-                else _Backend.TORCH_SDPA
-            )
+        self.attn_backend = (
+            backend
+            if backend
+            in {
+                _Backend.TORCH_SDPA,
+                _Backend.XFORMERS,
+                _Backend.PALLAS,
+                _Backend.ROCM_AITER_FA,
+                _Backend.FLASH_ATTN,
+            }
+            else _Backend.TORCH_SDPA
+        )
 
         self.attn_backend, self._flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
@@ -610,10 +615,7 @@ def forward(
             out = xops.memory_efficient_attention_forward(
                 query, key, value, scale=self.scale
            )
-        elif (
-            self.attn_backend == _Backend.TORCH_SDPA
-            or self.attn_backend == _Backend.IPEX
-        ):
+        elif self.attn_backend == _Backend.TORCH_SDPA:
             query, key, value = (x.transpose(1, 2) for x in (query, key, value))
             out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
             out = out.transpose(1, 2)
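
For vision attention on XPU, the net effect is that maybe_get_vit_flash_attn_backend keeps _Backend.FLASH_ATTN and binds flash_attn_varlen_func to the IPEX implementation instead of routing through a separate IPEX backend. A rough sketch of that resolution step (a guess at the helper's two-argument signature based on the hunks above; exact call sites vary by model):

    # Sketch only: resolve the ViT attention backend/function pair on XPU.
    from vllm.attention.backends.registry import _Backend
    from vllm.attention.layer import maybe_get_vit_flash_attn_backend

    attn_backend, flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
        _Backend.FLASH_ATTN,   # what XPUPlatform.get_vit_attn_backend() now returns
        use_upstream_fa=False,
    )
    # On XPU this yields _Backend.FLASH_ATTN paired with
    # vllm._ipex_ops.ipex_ops.flash_attn_varlen_func.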

vllm/attention/ops/vit_attn_wrappers.py (5 additions, 0 deletions)

@@ -15,6 +15,7 @@
 import einops
 import torch
 
+from vllm.platforms import current_platform
 from vllm.utils.torch_utils import direct_register_custom_op
 
 
@@ -66,6 +67,10 @@ def flash_attn_maxseqlen_wrapper(
 ) -> torch.Tensor:
     if is_rocm_aiter:
         from aiter import flash_attn_varlen_func
+    elif current_platform.is_xpu():
+        from vllm._ipex_ops import ipex_ops as ops
+
+        flash_attn_varlen_func = ops.flash_attn_varlen_func
     else:
         if use_upstream_fa:
             from flash_attn import flash_attn_varlen_func

vllm/model_executor/models/qwen2_5_vl.py (1 addition, 6 deletions)

@@ -711,7 +711,6 @@ def __init__(
             _Backend.TORCH_SDPA,
             _Backend.XFORMERS,
             _Backend.ROCM_AITER_FA,
-            _Backend.IPEX,
         }:
             raise RuntimeError(
                 f"Qwen2.5-VL does not support {self.attn_backend} backend now."
@@ -849,11 +848,7 @@ def compute_attn_mask_seqlen(
     ) -> tuple[torch.Tensor, torch.Tensor]:
         max_seqlen = torch.zeros([], device=cu_seqlens.device)
         seqlens = torch.zeros(1, device=cu_seqlens.device)
-        if (
-            self.attn_backend == _Backend.FLASH_ATTN
-            or self.attn_backend == _Backend.ROCM_AITER_FA
-            or self.attn_backend == _Backend.IPEX
-        ):
+        if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
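
The simplified branch derives max_seqlen directly from the cumulative sequence lengths consumed by the varlen flash-attention kernel. A tiny self-contained illustration of that arithmetic (plain PyTorch, no vLLM required):

    import torch

    # cu_seqlens holds cumulative lengths, e.g. three sequences of 4, 7 and 5 tokens
    cu_seqlens = torch.tensor([0, 4, 11, 16])
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]   # tensor([4, 7, 5])
    max_seqlen = seqlens.max()                   # tensor(7), fed to the varlen kernel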

vllm/model_executor/models/qwen2_vl.py (1 addition, 35 deletions)

@@ -382,7 +382,6 @@ def __init__(
             _Backend.TORCH_SDPA,
             _Backend.XFORMERS,
             _Backend.ROCM_AITER_FA,
-            _Backend.IPEX,
         }:
             raise RuntimeError(
                 f"Qwen2-VL does not support {self.attn_backend} backend now."
@@ -458,35 +457,6 @@ def forward(
                 causal=False,
             )
 
-            context_layer = rearrange(
-                output, "(b s) h d -> s b (h d)", b=batch_size
-            ).contiguous()
-        elif self.attn_backend == _Backend.IPEX:
-            from vllm._ipex_ops import ipex_ops
-
-            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
-
-            output = torch.empty(q.shape, dtype=q.dtype, device=q.device)
-            ipex_ops.varlen_attention(
-                q,
-                k,
-                v,
-                output,
-                cu_seqlens,
-                cu_seqlens,
-                None,
-                max_seqlen,
-                max_seqlen,
-                pdropout=0.0,
-                softmax_scale=1.0 / (q.shape[-1] ** 0.5),
-                zero_tensors=False,
-                is_causal=False,
-                return_softmax=False,
-                gen_=None,
-                window_size_left=-1,
-                window_size_right=-1,
-                logits_soft_cap=-1,
-            )
             context_layer = rearrange(
                 output, "(b s) h d -> s b (h d)", b=batch_size
             ).contiguous()
@@ -819,11 +789,7 @@ def compute_attn_mask_seqlen(
         self, cu_seqlens: torch.Tensor
     ) -> tuple[int | None, list[int] | None]:
         max_seqlen, seqlens = None, None
-        if (
-            self.attn_backend == _Backend.FLASH_ATTN
-            or self.attn_backend == _Backend.ROCM_AITER_FA
-            or self.attn_backend == _Backend.IPEX
-        ):
+        if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

vllm/platforms/xpu.py (1 addition, 1 deletion)

@@ -119,7 +119,7 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
     def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
         from vllm.attention.backends.registry import _Backend
 
-        return _Backend.IPEX
+        return _Backend.FLASH_ATTN
 
     @classmethod
     def inference_mode(cls):
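
With this one-line change an XPU build advertises FLASH_ATTN for vision attention, and the dispatch in vllm/attention/layer.py above routes it to the IPEX kernel. A quick hedged check (assumes running on an XPU build of vLLM):

    # Sketch only: confirm the platform hook now reports FLASH_ATTN on XPU.
    import torch
    from vllm.platforms import current_platform
    from vllm.attention.backends.registry import _Backend

    if current_platform.is_xpu():
        backend = current_platform.get_vit_attn_backend(head_size=64, dtype=torch.float16)
        assert backend == _Backend.FLASH_ATTN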
