File tree Expand file tree Collapse file tree 1 file changed +10
-10
lines changed
Expand file tree Collapse file tree 1 file changed +10
-10
lines changed Original file line number Diff line number Diff line change @@ -460,16 +460,16 @@ def _segment_ids_pos_to_seqlens_offsets(
460460 lambda x , y : jnp .equal (x , y ) * x ,
461461 )
462462 attn_mask = segment_mask
463- if attn_mask_type .is_causal ():
464- causal_mask = make_attention_mask (
465- segment_pos_q ,
466- segment_pos_kv ,
467- jnp .greater_equal ,
468- )
469- attn_mask = jnp .logical_and (segment_mask , causal_mask )
470-
471- swa_mask = make_swa_mask (segment_pos_q , segment_pos_kv , window_size , dtype = jnp .bool )
472- attn_mask = jnp .logical_and (attn_mask , swa_mask )
463+ # if attn_mask_type.is_causal():
464+ # causal_mask = make_attention_mask(
465+ # segment_pos_q,
466+ # segment_pos_kv,
467+ # jnp.greater_equal,
468+ # )
469+ # attn_mask = jnp.logical_and(segment_mask, causal_mask)
470+
471+ # swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool)
472+ # attn_mask = jnp.logical_and(attn_mask, swa_mask)
473473
474474 attn_mask_with_id = jnp .where (attn_mask , segment_mask_with_id , 0 )
475475 q_seqlen , q_offset , kv_seqlen , kv_offset = _mask_to_seqlens_offset (
You can’t perform that action at this time.
0 commit comments