Commit 6a75356

Try by only commenting out SW part and keeping causal part
Signed-off-by: Kshitij Lakhani <[email protected]>
1 parent 7a3ceee commit 6a75356

File tree: 1 file changed (+7, -7 lines)


transformer_engine/jax/attention.py

Lines changed: 7 additions & 7 deletions
@@ -460,13 +460,13 @@ def _segment_ids_pos_to_seqlens_offsets(
         lambda x, y: jnp.equal(x, y) * x,
     )
     attn_mask = segment_mask
-    # if attn_mask_type.is_causal():
-    #     causal_mask = make_attention_mask(
-    #         segment_pos_q,
-    #         segment_pos_kv,
-    #         jnp.greater_equal,
-    #     )
-    #     attn_mask = jnp.logical_and(segment_mask, causal_mask)
+    if attn_mask_type.is_causal():
+        causal_mask = make_attention_mask(
+            segment_pos_q,
+            segment_pos_kv,
+            jnp.greater_equal,
+        )
+        attn_mask = jnp.logical_and(segment_mask, causal_mask)

     # swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool)
     # attn_mask = jnp.logical_and(attn_mask, swa_mask)
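The re-enabled causal branch builds a mask by comparing query positions against key/value positions with `jnp.greater_equal`. A minimal self-contained sketch of that comparison, using plain `jnp` broadcasting instead of the library's `make_attention_mask` helper (whose exact signature is not shown in this diff), might look like:

```python
import jax.numpy as jnp

def causal_mask(segment_pos_q, segment_pos_kv):
    # Entry (i, j) is True when key position j is visible to query position i,
    # i.e. q_pos[i] >= kv_pos[j]: a query attends to itself and earlier tokens.
    # Broadcasting [Lq, 1] against [1, Lkv] yields an [Lq, Lkv] boolean mask.
    return jnp.greater_equal(segment_pos_q[:, None], segment_pos_kv[None, :])

pos = jnp.arange(4)
mask = causal_mask(pos, pos)
# mask is lower-triangular (including the diagonal) for equal q/kv positions.
```

In the patched code this causal mask is then AND-ed with the segment mask, so tokens can only attend to earlier positions within their own segment.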
