
Commit 4d6fe66

update
Signed-off-by: Yan Ma <[email protected]>
1 parent 15e31c0 commit 4d6fe66

4 files changed: +7 -13 lines changed

vllm/_ipex_ops.py

Lines changed: 2 additions & 2 deletions
@@ -281,8 +281,8 @@ def flash_attn_varlen_func(
     out: torch.Tensor | None = None,
     block_table: torch.Tensor | None = None,
     alibi_slopes: torch.Tensor | None = None,
-    window_size: torch.Tensor | None = None,
-    softcap: torch.Tensor | None = 0.0,
+    window_size: list[int] | None = None,
+    softcap: float | None = 0.0,
     seqused_k: torch.Tensor | None = None,
     cu_seqlens_k: torch.Tensor | None = None,
     # passed in qwen vl

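For reference, a minimal sketch of what the corrected annotations mean in practice: the sliding window is a (left, right) pair of ints and softcap is a plain scalar, neither of which needs to be a device tensor. This is an illustrative signature only, not the vLLM source; parameter names match the diff above, and the body is a stand-in.

import torch

# Hedged sketch (illustrative only): a wrapper signature using the corrected
# annotations from the diff above. Behavior is stubbed; real dispatch lives in ipex_ops.
def flash_attn_varlen_sketch(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor | None = None,
    block_table: torch.Tensor | None = None,
    alibi_slopes: torch.Tensor | None = None,
    window_size: list[int] | None = None,   # (left, right) ints, not a tensor
    softcap: float | None = 0.0,            # scalar float, not a tensor
    seqused_k: torch.Tensor | None = None,
    cu_seqlens_k: torch.Tensor | None = None,
) -> torch.Tensor | None:
    """Illustrates the intended argument types; does not compute attention."""
    if window_size is None:
        window_size = [-1, -1]   # (-1, -1) means no sliding-window restriction
    if softcap is None:
        softcap = 0.0            # 0.0 disables logit soft-capping
    return out
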
vllm/attention/layer.py

Lines changed: 2 additions & 5 deletions
@@ -127,21 +127,18 @@ def maybe_get_vit_flash_attn_backend(
         assert attn_backend == _Backend.FLASH_ATTN, (
             "XPU platform only supports FLASH_ATTN as vision attention backend."
         )
+        use_upstream_fa = False
     else:
         return _Backend.TORCH_SDPA, None
 
     if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
         if attn_backend == _Backend.ROCM_AITER_FA:
             from aiter import flash_attn_varlen_func
-        elif current_platform.is_xpu():
-            from vllm._ipex_ops import ipex_ops as ops
-
-            flash_attn_varlen_func = ops.flash_attn_varlen_func
         else:
             if use_upstream_fa:
                 from flash_attn import flash_attn_varlen_func
             else:
-                from vllm.vllm_flash_attn import flash_attn_varlen_func
+                from vllm.attention.utils.fa_utils import flash_attn_varlen_func
     else:
         flash_attn_varlen_func = None

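The net effect is that XPU no longer imports the ipex kernel directly here: it forces use_upstream_fa to False and falls through to the shared import from vllm.attention.utils.fa_utils. A simplified, standalone sketch of the resulting selection logic follows; the _Backend enum is stubbed out here as an assumption for illustration.

from enum import Enum, auto

class _Backend(Enum):            # stub of vLLM's backend enum, for illustration only
    FLASH_ATTN = auto()
    ROCM_AITER_FA = auto()
    TORCH_SDPA = auto()

# Hedged sketch of the post-change import selection in
# maybe_get_vit_flash_attn_backend (simplified; platform checks omitted).
def pick_flash_attn(attn_backend: _Backend, use_upstream_fa: bool):
    if attn_backend not in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
        return None
    if attn_backend == _Backend.ROCM_AITER_FA:
        from aiter import flash_attn_varlen_func              # ROCm AITER kernel
    elif use_upstream_fa:
        from flash_attn import flash_attn_varlen_func         # upstream flash-attn package
    else:
        # Shared path, now also taken on XPU since use_upstream_fa is forced to False there.
        from vllm.attention.utils.fa_utils import flash_attn_varlen_func
    return flash_attn_varlen_func
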
vllm/attention/ops/vit_attn_wrappers.py

Lines changed: 1 addition & 6 deletions
@@ -15,7 +15,6 @@
 import einops
 import torch
 
-from vllm.platforms import current_platform
 from vllm.utils.torch_utils import direct_register_custom_op
 
 
@@ -67,15 +66,11 @@ def flash_attn_maxseqlen_wrapper(
 ) -> torch.Tensor:
     if is_rocm_aiter:
         from aiter import flash_attn_varlen_func
-    elif current_platform.is_xpu():
-        from vllm._ipex_ops import ipex_ops as ops
-
-        flash_attn_varlen_func = ops.flash_attn_varlen_func
     else:
         if use_upstream_fa:
             from flash_attn import flash_attn_varlen_func
         else:
-            from vllm.vllm_flash_attn import flash_attn_varlen_func
+            from vllm.attention.utils.fa_utils import flash_attn_varlen_func
     q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
     output = flash_attn_varlen_func(
         q,

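The wrapper still packs the batch and sequence dimensions before calling the varlen kernel, as in the last context lines of the hunk above. A small self-contained example of that einops.rearrange step (shapes chosen arbitrarily for illustration):

import einops
import torch

# Illustration of the (b, s, ...) -> (b*s, ...) packing performed in
# flash_attn_maxseqlen_wrapper just before the varlen attention call.
b, s, heads, head_dim = 2, 8, 4, 64
q = torch.randn(b, s, heads, head_dim)
k = torch.randn(b, s, heads, head_dim)
v = torch.randn(b, s, heads, head_dim)

q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
assert q.shape == (b * s, heads, head_dim)   # tokens packed along a single axis
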
vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 2 additions & 0 deletions
@@ -364,6 +364,8 @@ def __init__(
 
         if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
             self.use_upstream_fa = True
+        if current_platform.is_xpu():
+            self.use_upstream_fa = False
         self.is_flash_attn_backend = self.attn_backend in {
             _Backend.FLASH_ATTN,
             _Backend.ROCM_AITER_FA,

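Because the XPU check comes after the ROCm check and is unconditional, it overrides any earlier setting of the flag, so XPU always takes the fa_utils import path. A minimal sketch of how the two checks compose (the False default here is an assumption for illustration, not taken from the diff):

# Hedged sketch (not the model code): how the two platform checks compose.
def resolve_use_upstream_fa(is_rocm: bool, is_xpu: bool, backend_is_flash_attn: bool) -> bool:
    use_upstream_fa = False                   # assumed default
    if is_rocm and backend_is_flash_attn:
        use_upstream_fa = True                # ROCm prefers the upstream flash-attn package
    if is_xpu:
        use_upstream_fa = False               # XPU always uses the fa_utils path
    return use_upstream_fa

assert resolve_use_upstream_fa(is_rocm=False, is_xpu=True, backend_is_flash_attn=True) is False
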
0 commit comments
