Commit 42ba6fb

jikunshang authored and yma11 committed
try adding varlen_attention to the flash_attn_varlen_func interface, to support the prefill-only scenario
Signed-off-by: Kunshang Ji <[email protected]>
1 parent 4a770f2 commit 42ba6fb

File tree

1 file changed, +37 -25 lines changed

vllm/_ipex_ops.py

Lines changed: 37 additions & 25 deletions
@@ -275,16 +275,16 @@ def flash_attn_varlen_func(
         k: torch.Tensor,
         v: torch.Tensor,
         cu_seqlens_q: torch.Tensor,
-        seqused_k: torch.Tensor,  # we don't support this in ipex kernel
         max_seqlen_q: int,
         max_seqlen_k: int,
         softmax_scale: float,
         causal: bool,
-        block_table: torch.Tensor,
-        alibi_slopes: torch.Tensor | None,
-        window_size: list[int] | None = None,
-        softcap: float | None = 0.0,
-        cu_seqlens_k: torch.Tensor | None = None,
+        block_table: Optional[torch.Tensor] = None,
+        alibi_slopes: Optional[torch.Tensor] = None,
+        window_size: Optional[list[int]] = None,
+        softcap: Optional[float] = 0.0,
+        seqused_k: Optional[torch.Tensor] = None,
+        cu_seqlens_k: Optional[torch.Tensor] = None,
         # The following parameters are not used in the ipex kernel currently;
         # we keep the API compatible with CUDA's.
         scheduler_metadata=None,
@@ -301,25 +301,37 @@ def flash_attn_varlen_func(
         else:
             assert len(window_size) == 2
             real_window_size = (window_size[0], window_size[1])
-        return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
-            out,
-            q.contiguous(),
-            k,
-            v,
-            cu_seqlens_q,
-            seqused_k,
-            max_seqlen_q,
-            max_seqlen_k,
-            softmax_scale,
-            causal,
-            block_table,
-            alibi_slopes,
-            softcap=softcap,
-            window_size_left=real_window_size[0],
-            window_size_right=real_window_size[1],
-            k_scale=1.0,
-            v_scale=1.0,
-        )
+
+        if block_table is None:
+            # TODO: check whether cu_seqlens_k or seqused_k should be used here
+            ipex_ops.varlen_attention(q.contiguous(), k.contiguous(),
+                                      v.contiguous(), out, cu_seqlens_q,
+                                      cu_seqlens_k, max_seqlen_q, max_seqlen_k,
+                                      0.0, softmax_scale, False, causal, False,
+                                      None, real_window_size[0],
+                                      real_window_size[1], softcap)
+            return out
+        else:
+            return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+                out,
+                q.contiguous(),
+                k,
+                v,
+                cu_seqlens_q,
+                seqused_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                softmax_scale,
+                causal,
+                block_table,
+                alibi_slopes,
+                sink=s_aux,
+                softcap=softcap,
+                window_size_left=real_window_size[0],
+                window_size_right=real_window_size[1],
+                k_scale=1.0,
+                v_scale=1.0,
+            )
 
     @staticmethod
     def get_scheduler_metadata(
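For reference, a minimal usage sketch of the wrapper after this change. This is not part of the commit: it assumes an XPU-enabled PyTorch build with intel_extension_for_pytorch installed, that out and q are the leading parameters of flash_attn_varlen_func (they sit just above the hunk shown), and illustrative shapes. Leaving block_table unset exercises the new prefill-only varlen_attention path; passing a block_table keeps the original PagedAttention path.

# Hypothetical usage sketch (not from the commit); requires an XPU device.
import torch

from vllm._ipex_ops import ipex_ops

num_heads, head_size = 8, 64
seq_lens = [5, 7]                 # two prefill sequences packed together
total_tokens = sum(seq_lens)      # 12 tokens along the packed dimension

q = torch.randn(total_tokens, num_heads, head_size,
                dtype=torch.float16, device="xpu")
k = torch.randn_like(q)
v = torch.randn_like(q)
out = torch.empty_like(q)

# Cumulative sequence lengths, one entry per sequence boundary: [0, 5, 12].
cu_seqlens = torch.tensor([0, 5, 12], dtype=torch.int32, device="xpu")

# block_table is omitted, so the wrapper dispatches to
# ipex_ops.varlen_attention -- the prefill-only path added by this commit.
ipex_ops.flash_attn_varlen_func(
    out,
    q,
    k,
    v,
    cu_seqlens_q=cu_seqlens,
    max_seqlen_q=max(seq_lens),
    max_seqlen_k=max(seq_lens),
    softmax_scale=head_size ** -0.5,
    causal=True,
    cu_seqlens_k=cu_seqlens,      # the new path reads cu_seqlens_k
)

Making block_table, alibi_slopes, and seqused_k keyword arguments with None defaults is what lets a prefill-only caller like the one above skip the paged-KV bookkeeping entirely, while existing paged-attention call sites keep working unchanged.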
