Skip to content

Commit

Permalink
[JAX] Fix issues when mask/sequence_descriptor is None (#1477)
Browse files (browse the repository at this point in the history)
Fix issues when mask/sequence_descriptor is None

Signed-off-by: Reese Wang <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>
  • Loading branch information
zlsh80826 and phu0ngng authored Feb 14, 2025
1 parent 45e9d8b commit dfbf4dd
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
17 changes: 10 additions & 7 deletions tests/jax/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,13 +556,16 @@ def generate_random_segment_ids(
else:
match self.seq_desc_format:
case SeqDescFormat.Mask:
- self.sequence_desciptor = make_mask(
- self.segment_ids_q,
- self.segment_ids_kv,
- self.segment_pos_q,
- self.segment_pos_kv,
- self.attn_mask_type,
- )
+ if self.attn_mask_type == AttnMaskType.NO_MASK:
+ self.sequence_desciptor = None
+ else:
+ self.sequence_desciptor = make_mask(
+ self.segment_ids_q,
+ self.segment_ids_kv,
+ self.segment_pos_q,
+ self.segment_pos_kv,
+ self.attn_mask_type,
+ )
case SeqDescFormat.Seqlens:
self.sequence_desciptor = SequenceDescriptor.from_seqlens(
(
Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/jax/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,7 @@ def fused_attn(
AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK,
QKVLayout.T3HD, 0.125, 0, True, 3)
"""
- if isinstance(sequence_descriptor, jnp.ndarray):
+ if sequence_descriptor is None or isinstance(sequence_descriptor, jnp.ndarray):
warnings.warn(
"Pass mask to fused_attn is deprecated, please use SequenceDescriptor instead. "
+ "See help(transformer_engine.jax.attention.SequenceDescriptor) for details.",
Expand Down

0 comments on commit dfbf4dd

Please sign in to comment.