@@ -275,16 +275,16 @@ def flash_attn_varlen_func(
         k: torch.Tensor,
         v: torch.Tensor,
         cu_seqlens_q: torch.Tensor,
-        seqused_k: torch.Tensor,  # we don't support this in ipex kernel
         max_seqlen_q: int,
         max_seqlen_k: int,
         softmax_scale: float,
         causal: bool,
-        block_table: torch.Tensor,
-        alibi_slopes: torch.Tensor | None,
-        window_size: list[int] | None = None,
-        softcap: float | None = 0.0,
-        cu_seqlens_k: torch.Tensor | None = None,
+        block_table: Optional[torch.Tensor] = None,
+        alibi_slopes: Optional[torch.Tensor] = None,
+        window_size: Optional[list[int]] = None,
+        softcap: Optional[float] = 0.0,
+        seqused_k: Optional[torch.Tensor] = None,
+        cu_seqlens_k: Optional[torch.Tensor] = None,
         # The following parameters are not used in the ipex kernel currently;
         # we keep the API compatible with CUDA's.
         scheduler_metadata=None,
@@ -301,25 +301,37 @@ def flash_attn_varlen_func(
         else:
             assert len(window_size) == 2
             real_window_size = (window_size[0], window_size[1])
-        return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
-            out,
-            q.contiguous(),
-            k,
-            v,
-            cu_seqlens_q,
-            seqused_k,
-            max_seqlen_q,
-            max_seqlen_k,
-            softmax_scale,
-            causal,
-            block_table,
-            alibi_slopes,
-            softcap=softcap,
-            window_size_left=real_window_size[0],
-            window_size_right=real_window_size[1],
-            k_scale=1.0,
-            v_scale=1.0,
-        )
+
+        if block_table is None:
+            # TODO: check whether cu_seqlens_k or seqused_k should be used here.
+            ipex_ops.varlen_attention(q.contiguous(), k.contiguous(),
+                                      v.contiguous(), out, cu_seqlens_q,
+                                      cu_seqlens_k, max_seqlen_q, max_seqlen_k,
+                                      0.0, softmax_scale, False, causal, False,
+                                      None, real_window_size[0],
+                                      real_window_size[1], softcap)
+            return out
+        else:
+            return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+                out,
+                q.contiguous(),
+                k,
+                v,
+                cu_seqlens_q,
+                seqused_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                softmax_scale,
+                causal,
+                block_table,
+                alibi_slopes,
+                sink=s_aux,
+                softcap=softcap,
+                window_size_left=real_window_size[0],
+                window_size_right=real_window_size[1],
+                k_scale=1.0,
+                v_scale=1.0,
+            )

     @staticmethod
     def get_scheduler_metadata(
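
With this change, `block_table`, `alibi_slopes`, `window_size`, `seqused_k`, and `cu_seqlens_k` become optional, and the wrapper dispatches on `block_table`: when it is `None`, the non-paged `ipex_ops.varlen_attention` kernel is used and `cu_seqlens_k` is required; otherwise the paged `ipex.llm.modules.PagedAttention.flash_attn_varlen_func` path runs as before, now also forwarding `sink=s_aux`. The snippet below is a minimal caller sketch and is not part of this diff: the tensor shapes, dtypes, the `xpu` device, and calling the (static-method) wrapper directly are illustrative assumptions.

```python
import torch

# Toy sizes for illustration only; real callers take these from attention metadata.
num_tokens, num_heads, head_dim = 32, 8, 64
q = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.float16, device="xpu")
k = torch.randn_like(q)
v = torch.randn_like(q)
out = torch.empty_like(q)
# Two sequences of 16 tokens each, expressed as cumulative lengths.
cu_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device="xpu")

# Prefill-style call: block_table is left as None, so the new branch falls back to
# ipex_ops.varlen_attention and cu_seqlens_k must be supplied.
# flash_attn_varlen_func here is the wrapper modified above; import/qualify it
# however the backend actually exposes it.
flash_attn_varlen_func(
    out,
    q,
    k,
    v,
    cu_seqlens_q=cu_seqlens,
    max_seqlen_q=16,
    max_seqlen_k=16,
    softmax_scale=head_dim**-0.5,
    causal=True,
    cu_seqlens_k=cu_seqlens,
)

# Decode-style call (sketch only): passing a block_table plus seqused_k instead
# routes to the paged PagedAttention.flash_attn_varlen_func kernel.
# flash_attn_varlen_func(out, q, k, v, cu_seqlens_q, max_seqlen_q, max_seqlen_k,
#                        softmax_scale, causal, block_table=block_table,
#                        seqused_k=seqused_k)
```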