[JAX] Expose context parallel params to JAX DPA API (#1292)
Exposed context parallel params to the DPA API.

Signed-off-by: Md Fahim Faysal Khan <[email protected]>
Signed-off-by: Michael Goldfarb <[email protected]>

Co-authored-by: Michael Goldfarb <[email protected]>
kocchop and mgoldfarb-nvidia authored Nov 4, 2024
1 parent c42beef commit d725686
Showing 3 changed files with 41 additions and 24 deletions.
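For orientation, here is a hedged sketch of how the newly exposed knobs might be set from user code. Only context_parallel_causal_load_balanced and context_parallel_axis are taken from this commit; the import path and the remaining constructor arguments (head_dim, num_attention_heads, input shapes) are assumptions about the surrounding TE/JAX API and may differ by version.

# Hedged sketch: only the two context_parallel_* arguments are introduced by this
# commit; the other constructor arguments and the import path are assumed and may
# differ between Transformer Engine versions.
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import DotProductAttention  # assumed export path

attn = DotProductAttention(
    head_dim=64,                    # assumed existing argument
    num_attention_heads=16,         # assumed existing argument
    transpose_batch_sequence=False,
    context_parallel_causal_load_balanced=True,  # new: inputs pre-reordered for causal CP load balancing
    context_parallel_axis="cp",                  # new: name of the mesh axis used for context parallelism
)

# (batch, seqlen, heads, head_dim) inputs; shapes are illustrative.
q = k = v = jnp.zeros((2, 2048, 16, 64), dtype=jnp.bfloat16)
params = attn.init(jax.random.PRNGKey(0), q, k, v)
out = attn.apply(params, q, k, v)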
44 changes: 26 additions & 18 deletions tests/jax/test_distributed_fused_attn.py
@@ -133,7 +133,6 @@ def test_self_attn(
seqlen,
hidden,
None, # no window
False, # not context parallel
):
pytest.skip(f"No FusedAttn backend found")

@@ -268,7 +267,6 @@ def test_cross_attn(
seqlen,
hidden,
None, # no window
False, # not context parallel
):
pytest.skip(f"No FusedAttn backend found")

@@ -425,22 +423,32 @@ def test_contex_parallel_self_attn(
num_kv_heads = num_head // kv_groups
scaling_factor = 1.0 / np.sqrt(num_head)

if not is_fused_attn_kernel_available(
dtype,
dtype,
qkv_layout,
attn_bias_type,
attn_mask_type,
dropout_prob,
num_head,
num_kv_heads,
seqlen,
seqlen,
hidden,
None, # no window
cp_size > 1,
):
pytest.skip(f"No FusedAttn backend found")
def check_has_backend_for_mask(mask_type):
return is_fused_attn_kernel_available(
dtype,
dtype,
qkv_layout,
attn_bias_type,
mask_type,
dropout_prob,
num_head,
num_kv_heads,
seqlen,
seqlen,
hidden,
None,
) # no SWA for CP

# For causal masking we depend on having bottom right support also.
# The API does not check this; instead we rely on lower level checks to raise
# an exception if the backend for that step is not supported. This was a
# deliberate API decision to keep the CP size or flag out of the function.
has_backend = check_has_backend_for_mask(attn_mask_type)
if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK:
has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)

if not has_backend:
pytest.skip(f"No FusedAttn backend found {cp_size=} {attn_mask_type=}.")

if dp_size > 1 and batch % dp_size != 0:
pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}")
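As the comment in the new helper explains, running context parallelism with a plain causal mask also requires bottom-right-aligned causal support from the backend. For readers unfamiliar with the distinction, a small hedged illustration (not part of this commit) of the two alignments for a query chunk that is shorter than its key/value span:

# Hedged illustration: top-left vs. bottom-right aligned causal masks for a
# 2-query / 4-key block, as plain boolean matrices. Not taken from this commit.
import jax.numpy as jnp

q_len, kv_len = 2, 4
top_left = jnp.tril(jnp.ones((q_len, kv_len), dtype=bool))
# Bottom-right alignment shifts the diagonal so the last query sees every key.
bottom_right = jnp.tril(jnp.ones((q_len, kv_len), dtype=bool), k=kv_len - q_len)
# top_left:     [[1, 0, 0, 0],     bottom_right:  [[1, 1, 1, 0],
#                [1, 1, 0, 0]]                     [1, 1, 1, 1]]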
6 changes: 0 additions & 6 deletions transformer_engine/jax/attention.py
@@ -190,7 +190,6 @@ def is_fused_attn_kernel_available(
kv_max_seqlen,
head_dim,
window_size: Optional[Tuple[int, int]] = None,
is_context_parallel: bool = False,
):
"""
To check whether the fused attention kernel is supported
@@ -215,11 +214,6 @@ def make_helper(attn_mask_type):
if not make_helper(attn_mask_type).is_fused_attn_kernel_available():
return False

# For context parallel need to check additional masking types
if is_context_parallel and attn_mask_type == AttnMaskType.CAUSAL_MASK:
if not make_helper(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK).is_fused_attn_kernel_available():
return False

return True


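Since is_fused_attn_kernel_available no longer takes an is_context_parallel flag, a caller that plans to use context parallelism with a causal mask now probes the bottom-right variant itself, exactly as the updated test does. A hedged sketch of that pattern, assuming the usual shape and dtype arguments are in scope and that both names are importable from transformer_engine.jax.attention:

# Hedged sketch mirroring the updated test: probe both mask variants when the
# caller intends to run with context parallelism (cp_size > 1).
from transformer_engine.jax.attention import AttnMaskType, is_fused_attn_kernel_available

def cp_fused_attn_available(dtype, qkv_layout, attn_bias_type, attn_mask_type,
                            dropout_prob, num_head, num_kv_heads, seqlen, hidden, cp_size):
    def check(mask_type):
        return is_fused_attn_kernel_available(
            dtype, dtype, qkv_layout, attn_bias_type, mask_type,
            dropout_prob, num_head, num_kv_heads, seqlen, seqlen, hidden,
            None,  # no sliding window with CP
        )

    available = check(attn_mask_type)
    if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK:
        # Context parallelism with a causal mask also needs bottom-right support.
        available = available and check(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)
    return available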
15 changes: 15 additions & 0 deletions transformer_engine/jax/flax/transformer.py
@@ -262,6 +262,8 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = False
window_size: Optional[Tuple[int, int]] = None
context_parallel_causal_load_balanced: bool = False
context_parallel_axis: str = ""

@nn.compact
def __call__(
@@ -308,6 +310,8 @@ def __call__(
dropout_probability=self.attention_dropout,
is_training=not deterministic,
window_size=self.window_size,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
)
elif self.qkv_layout == QKVLayout.BSHD_BS2HD:
"""kvpacked format, treat
@@ -331,6 +335,8 @@ def __call__(
dropout_probability=self.attention_dropout,
is_training=not deterministic,
window_size=self.window_size,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
)
elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
if self.transpose_batch_sequence:
@@ -349,6 +355,8 @@ def __call__(
dropout_probability=self.attention_dropout,
is_training=not deterministic,
window_size=self.window_size,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
)
else:
raise ValueError(f"Unsupported {self.qkv_layout=}.")
@@ -463,6 +471,9 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
window_size: Optional[Tuple[int, int]], default = None
Sliding window size. The default value is no sliding window.
context_parallel_causal_load_balanced: bool, default = False
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis: str, default = ""
The name of the context parallel axis.
Optimization parameters
-----------------------
@@ -483,6 +494,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True
window_size: Optional[Tuple[int, int]] = None
context_parallel_causal_load_balanced: bool = False
context_parallel_axis: str = ""

@nn.compact
def __call__(
@@ -614,6 +627,8 @@ def __call__(
transpose_batch_sequence=self.transpose_batch_sequence,
qkv_layout=qkv_layout,
window_size=self.window_size,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
)(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)

return x
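The context_parallel_axis string names a mesh axis. A minimal hedged sketch of the kind of device mesh it would refer to; the axis names and device layout below are illustrative, not taken from this commit:

# Hedged sketch: context_parallel_axis is expected to name a mesh axis; the
# 2x2x2 layout and the axis names here are illustrative only.
import numpy as np
import jax
from jax.sharding import Mesh

devices = np.array(jax.devices()).reshape(2, 2, 2)  # assumes 8 available devices
mesh = Mesh(devices, axis_names=("dp", "cp", "tp"))

with mesh:
    # Modules built here would be given context_parallel_axis="cp" so the fused
    # attention primitives know which mesh axis shards the sequence dimension.
    ...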
