Skip to content

Commit 7a3ceee

Browse files
Test to see if SWA and Causal compute can be removed from seqlens and offsets calc
Signed-off-by: Kshitij Lakhani <[email protected]>
1 parent afd15a1 commit 7a3ceee

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

transformer_engine/jax/attention.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -460,16 +460,16 @@ def _segment_ids_pos_to_seqlens_offsets(
460460
lambda x, y: jnp.equal(x, y) * x,
461461
)
462462
attn_mask = segment_mask
463-
if attn_mask_type.is_causal():
464-
causal_mask = make_attention_mask(
465-
segment_pos_q,
466-
segment_pos_kv,
467-
jnp.greater_equal,
468-
)
469-
attn_mask = jnp.logical_and(segment_mask, causal_mask)
470-
471-
swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool)
472-
attn_mask = jnp.logical_and(attn_mask, swa_mask)
463+
# if attn_mask_type.is_causal():
464+
# causal_mask = make_attention_mask(
465+
# segment_pos_q,
466+
# segment_pos_kv,
467+
# jnp.greater_equal,
468+
# )
469+
# attn_mask = jnp.logical_and(segment_mask, causal_mask)
470+
471+
# swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool)
472+
# attn_mask = jnp.logical_and(attn_mask, swa_mask)
473473

474474
attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0)
475475
q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset(

0 commit comments

Comments
 (0)