@@ -270,21 +270,23 @@ def reshape_and_cache_flash(

     @staticmethod
     def flash_attn_varlen_func(
-        out: torch.Tensor,
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
         cu_seqlens_q: torch.Tensor,
         max_seqlen_q: int,
         max_seqlen_k: int,
-        softmax_scale: float,
-        causal: bool,
-        block_table: Optional[torch.Tensor] = None,
-        alibi_slopes: Optional[torch.Tensor] = None,
-        window_size: Optional[list[int]] = None,
-        softcap: Optional[float] = 0.0,
-        seqused_k: Optional[torch.Tensor] = None,
-        cu_seqlens_k: Optional[torch.Tensor] = None,
+        softmax_scale: float | None = None,
+        causal: bool = False,
+        out: torch.Tensor | None = None,
+        block_table: torch.Tensor | None = None,
+        alibi_slopes: torch.Tensor | None = None,
+        window_size: list[int] | None = None,
+        softcap: float | None = 0.0,
+        seqused_k: torch.Tensor | None = None,
+        cu_seqlens_k: torch.Tensor | None = None,
+        # passed in by Qwen-VL
+        dropout_p: float = 0.0,
         # The following parameters are not used in ipex kernel currently,
         # we keep API compatible to CUDA's.
         scheduler_metadata=None,
@@ -295,6 +297,8 @@ def flash_attn_varlen_func(
         num_splits=0,
         s_aux: torch.Tensor | None = None,
     ):
+        if out is None:
+            out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
         real_window_size: tuple[int, int]
         if window_size is None:
             real_window_size = (-1, -1)
@@ -303,13 +307,31 @@ def flash_attn_varlen_func(
             real_window_size = (window_size[0], window_size[1])

         if block_table is None:
-            # need check use cu_seqlens_k or seqused_k!!!
-            ipex_ops.varlen_attention(q.contiguous(), k.contiguous(),
-                                      v.contiguous(), out, cu_seqlens_q,
-                                      cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                                      0.0, softmax_scale, False, causal, False,
-                                      None, real_window_size[0],
-                                      real_window_size[1], softcap)
+            assert cu_seqlens_k is not None, (
+                "cu_seqlens_k can't be None when calling varlen_attention."
+            )
+            if softmax_scale is None:
+                softmax_scale = q.shape[-1] ** (-0.5)
+            ipex_ops.varlen_attention(
+                q.contiguous(),
+                k.contiguous(),
+                v.contiguous(),
+                out,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                None,
+                max_seqlen_q,
+                max_seqlen_k,
+                0.0,
+                softmax_scale,
+                False,
+                causal,
+                False,
+                None,
+                real_window_size[0],
+                real_window_size[1],
+                -1,
+            )
             return out
         else:
             return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
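For context on what the new defaults above mean in the varlen layout: q, k, and v arrive flattened as [total_tokens, num_heads, head_dim], cu_seqlens_q / cu_seqlens_k hold cumulative per-sequence offsets into that flattened batch, softmax_scale falls back to head_dim ** -0.5, and out is allocated lazily when omitted. The snippet below is a minimal pure-PyTorch sketch of that behavior, not the IPEX kernel and not code from this commit; the helper name varlen_attention_reference and the bottom-right-aligned causal mask are illustrative assumptions.

import torch

def varlen_attention_reference(q, k, v, cu_seqlens_q, cu_seqlens_k,
                               softmax_scale=None, causal=False, out=None):
    # q/k/v are flattened varlen batches: [total_tokens, num_heads, head_dim].
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5
    if out is None:
        out = torch.empty_like(q)
    for i in range(cu_seqlens_q.numel() - 1):
        qs, qe = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
        ks, ke = cu_seqlens_k[i].item(), cu_seqlens_k[i + 1].item()
        qi = q[qs:qe].transpose(0, 1)  # [num_heads, q_len, head_dim]
        ki = k[ks:ke].transpose(0, 1)
        vi = v[ks:ke].transpose(0, 1)
        scores = qi @ ki.transpose(-1, -2) * softmax_scale
        if causal:
            # Bottom-right-aligned causal mask (assumed flash-attn-style semantics).
            q_len, k_len = qe - qs, ke - ks
            keep = torch.ones(q_len, k_len, device=q.device).tril(k_len - q_len).bool()
            scores = scores.masked_fill(~keep, float("-inf"))
        out[qs:qe] = (scores.softmax(dim=-1) @ vi).transpose(0, 1).to(out.dtype)
    return out

# Example: two sequences of lengths 3 and 5 packed into one flattened batch.
q = torch.randn(8, 4, 64)
k = torch.randn(8, 4, 64)
v = torch.randn(8, 4, 64)
cu = torch.tensor([0, 3, 8], dtype=torch.int32)
o = varlen_attention_reference(q, k, v, cu, cu, causal=True)

Calling the updated flash_attn_varlen_func the same way (omitting out and softmax_scale, and passing cu_seqlens_k when no block_table is given) exercises the defaults introduced by this change.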