[JAX] Expose THD format to the flax module (#1480)

* Expose THD to flax MHA module

Signed-off-by: Reese Wang <[email protected]>

* Enhance docs

Signed-off-by: Reese Wang <[email protected]>

---------

Signed-off-by: Reese Wang <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>
zlsh80826 and phu0ngng authored Feb 14, 2025
1 parent dfbf4dd commit af7b2b4
Showing 2 changed files with 61 additions and 17 deletions.
transformer_engine/jax/cpp_extensions/attention.py (6 additions, 1 deletion)
@@ -645,6 +645,8 @@ def partition(config, mesh, arg_infos, result_infos):
)
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding
arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings)
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
impl = partial(FusedAttnFwdPrimitive.impl, config=config)
@@ -1042,7 +1044,10 @@ def partition(config, mesh, arg_infos, result_infos):
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings)
out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)

def sharded_impl(
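
Illustrative sketch (not part of this commit): both partition() changes above apply the same pattern. The two trailing operands added for the THD path are explicitly given the shardings of the operands two positions earlier, presumably so the new sequence-offset tensors are partitioned the same way as their corresponding sequence-length tensors. The operand names below are assumptions; the real order is defined by the FusedAttn primitives in attention.py.

# Illustrative sketch only; operand names and positions are assumptions.
def mirror_trailing_shardings(arg_shardings):
    # Suppose the operand list ends with:
    #   ..., q_seqlen, kv_seqlen, q_seq_offsets, kv_seq_offsets
    arg_shardings = list(arg_shardings)
    arg_shardings[-1] = arg_shardings[-3]  # kv_seq_offsets <- kv_seqlen sharding
    arg_shardings[-2] = arg_shardings[-4]  # q_seq_offsets  <- q_seqlen sharding
    return tuple(arg_shardings)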
transformer_engine/jax/flax/transformer.py (55 additions, 16 deletions)
@@ -24,7 +24,7 @@

from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
from ..attention import AttnBiasType, AttnMaskType, QKVLayout
from ..attention import AttnBiasType, AttnMaskType, QKVLayout, SequenceDescriptor
from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type
from ..attention import fused_attn
from ..softmax import SoftmaxType
@@ -267,6 +267,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = False
window_size: Optional[Tuple[int, int]] = None
max_segments_per_seq: Optional[int] = 1
context_parallel_causal_load_balanced: bool = False
context_parallel_axis: str = ""

@@ -276,7 +277,7 @@ def __call__(
query: Array,
key: Array,
value: Array,
mask: Optional[Array] = None,
sequence_descriptor: Optional[SequenceDescriptor] = None,
bias: Optional[Array] = None,
*,
dropout_rng: Optional[PRNGKey] = None,
@@ -293,8 +294,7 @@ def __call__(
scale_factor = self.scale_factor
del self.scale_factor

# TODO(rewang): integrate THD format
if self.qkv_layout == QKVLayout.BS3HD:
if self.qkv_layout.is_qkvpacked():
"""qkvpacked format, treat
query: qkvpacked tensor, shape = [..., 3, h, d]
key: ignore
@@ -306,7 +306,7 @@ def __call__(
x = fused_attn(
(qkv_packed,),
bias,
mask,
sequence_descriptor,
seed,
attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type,
@@ -315,10 +315,11 @@ def __call__(
dropout_probability=self.attention_dropout,
is_training=not deterministic,
window_size=self.window_size,
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
)
elif self.qkv_layout == QKVLayout.BSHD_BS2HD:
elif self.qkv_layout.is_kvpacked():
"""kvpacked format, treat
query: query tensor, shape = [..., h, d]
key: kvpacked tensor, shape = [..., 2, h, d]
@@ -331,7 +332,7 @@ def __call__(
x = fused_attn(
(query, kv_packed),
bias,
mask,
sequence_descriptor,
seed,
attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type,
@@ -340,18 +341,19 @@ def __call__(
dropout_probability=self.attention_dropout,
is_training=not deterministic,
window_size=self.window_size,
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
)
elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
elif self.qkv_layout.is_separate():
if self.transpose_batch_sequence:
query = query.transpose([1, 0, 2, 3])
key = key.transpose([1, 0, 2, 3])
value = value.transpose([1, 0, 2, 3])
x = fused_attn(
(query, key, value),
bias,
mask,
sequence_descriptor,
seed,
attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type,
@@ -360,6 +362,7 @@ def __call__(
dropout_probability=self.attention_dropout,
is_training=not deterministic,
window_size=self.window_size,
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
)
@@ -437,6 +440,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
.. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'.
.. note:: THD format only supports the 'padding' or 'causal_padding' mask types.
attn_bias_type: Optional[str], default = None
Type of the attention bias passed in the attention.
Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
@@ -451,13 +456,15 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
qkv_layout: str, default = 'bshd_bshd_bshd'
Specifies the dimensional layout format for the query, key, and value tensors in __call__().
It indicates how the inputs are processed.
Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd'}. Where
Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd', 't3hd', 'thd_t2hd', 'thd_thd_thd'}.
* bs3hd: query tensor is treated as a qkvpacked tensor with shape = [b, s, 3, h, d].
key and value arguments in :attr:`__call__()` are ignored in this layout.
* bshd_bs2hd: query tensor with shape = [b, s, h, d]. key tensor is treated as a kvpacked
tensor with shape = [b, s, 2, h, d]. `value` argument in :attr:`__call__()` is ignored.
* bshd_bshd_bshd: query, key, and value are separated with shape = [b, s, h, d].
* t3hd/thd_t2hd/thd_thd_thd: same layouts as the bshd series, but they allow multiple
sequences to be packed in a batch, also known as sequence packing.
Explanation of denotations:
@@ -476,6 +483,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
window_size: Optional[Tuple[int, int]], default = None
Sliding window size. The default value is no sliding window.
max_segments_per_seq: Optional[int], default = 1
The maximum number of segments per sequence, used for the THD format (sequence packing).
context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
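
Illustrative sketch (not part of this commit): a configuration that exercises the options documented above. Only qkv_layout, attn_mask_type, max_segments_per_seq, and transpose_batch_sequence appear in this diff; the remaining field names (head_dim, num_attention_heads) and the import path are assumptions made to keep the example self-contained.

# Illustrative sketch; head_dim, num_attention_heads, and the import path are
# assumed, the other fields are the ones documented in this diff.
from transformer_engine.jax.flax import DotProductAttention

attn = DotProductAttention(
    head_dim=64,                      # assumed field name
    num_attention_heads=16,           # assumed field name
    qkv_layout="thd_thd_thd",         # separate q/k/v in the packed THD layout
    attn_mask_type="causal_padding",  # THD requires a padding-based mask type
    max_segments_per_seq=2,           # at most two packed segments per row
    transpose_batch_sequence=False,   # inputs are (batch, seqlen, ...)
)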
@@ -502,6 +511,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True
window_size: Optional[Tuple[int, int]] = None
max_segments_per_seq: Optional[int] = 1
context_parallel_causal_load_balanced: bool = False
context_parallel_axis: str = ""

@@ -511,10 +521,11 @@ def __call__(
query: Array,
key: Array,
value: Array,
mask: Optional[Array] = None,
sequence_descriptor: Optional[Union[SequenceDescriptor, Array]] = None,
bias: Optional[Array] = None,
*,
deterministic: bool = False,
mask: Optional[Union[SequenceDescriptor, Array]] = None,
) -> Array:
"""
Parameters
@@ -542,6 +553,15 @@ def __call__(
Output tensors.
"""

if mask is not None:
if sequence_descriptor is not None:
raise ValueError(
"sequence_descriptor and mask cannot be provided at the same time."
)
warnings.warn("mask is deprecated, please use sequence_descriptor instead.")
sequence_descriptor = mask
del mask

# For internal API, we use enum to maintain
if self.attn_bias_type is None:
attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
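
Illustrative sketch (not part of this commit) of the calling-convention change enforced by the shim earlier in this hunk; `attn`, `q`, `k`, `v`, and `mask` are placeholders for an already-instantiated module and its inputs.

# Deprecated keyword: still accepted, emits a warning, then is forwarded
# to sequence_descriptor internally.
out = attn(q, k, v, mask=mask)
# Replacement keyword; a plain mask Array is still accepted here, and a
# SequenceDescriptor can be passed for the packed THD layouts.
out = attn(q, k, v, sequence_descriptor=mask)
# Passing both keywords at once raises a ValueError.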
@@ -604,16 +624,18 @@ def __call__(

if not use_fused_attn:
# unfused attention only supports split query, key, value
if qkv_layout == QKVLayout.BS3HD:
if qkv_layout.is_qkvpacked():
query, key, value = jnp.split(query, [1, 2], axis=-3)
query, key, value = map(
functools.partial(jnp.squeeze, axis=-3), [query, key, value]
)
elif qkv_layout == QKVLayout.BSHD_BS2HD:
elif qkv_layout.is_kvpacked():
key, value = jnp.split(key, [1], axis=-3)
key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value])
else:
assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
assert qkv_layout.is_separate()

assert sequence_descriptor is None or isinstance(sequence_descriptor, jnp.ndarray)

x = _UnfusedDotProductAttention(
attention_dropout=self.attention_dropout,
Expand All @@ -625,7 +647,15 @@ def __call__(
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence,
window_size=self.window_size,
)(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)
)(
query,
key,
value,
sequence_descriptor,
bias,
dropout_rng=dropout_rng,
deterministic=deterministic,
)
else:
x = _FusedDotProductAttention(
attention_dropout=self.attention_dropout,
Expand All @@ -637,9 +667,18 @@ def __call__(
transpose_batch_sequence=self.transpose_batch_sequence,
qkv_layout=qkv_layout,
window_size=self.window_size,
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
)(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)
)(
query,
key,
value,
sequence_descriptor,
bias,
dropout_rng=dropout_rng,
deterministic=deterministic,
)

return x

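Illustrative sketch (not part of this commit): an end-to-end call on the new THD path, reusing the `attn` module sketched earlier. The SequenceDescriptor factory shown (from_seqlens_and_offsets) and its argument layout are assumptions; only the class name and its role as the sequence_descriptor argument come from this diff, so check transformer_engine.jax.attention for the actual constructors.

# Illustrative sketch; the SequenceDescriptor factory name and argument
# layout below are assumptions.
import jax
import jax.numpy as jnp
from transformer_engine.jax.attention import SequenceDescriptor

# Two rows, each packing up to max_segments_per_seq=2 segments into a
# total sequence length of 8 tokens (matching the module sketched above).
seqlens = jnp.asarray([[5, 3], [6, 2]], dtype=jnp.int32)  # per-segment lengths
offsets = jnp.asarray([[0, 5], [0, 6]], dtype=jnp.int32)  # per-segment starts
seq_desc = SequenceDescriptor.from_seqlens_and_offsets(   # assumed factory
    seqlens=(seqlens, seqlens),                           # (q, kv)
    seq_offsets=(offsets, offsets),
)

# Packed [b, s, h, d] inputs; positions beyond the packed segments are unused.
q = jnp.zeros((2, 8, 16, 64), dtype=jnp.bfloat16)
k, v = jnp.zeros_like(q), jnp.zeros_like(q)

params = attn.init(jax.random.PRNGKey(0), q, k, v, seq_desc)
out = attn.apply(params, q, k, v, seq_desc, deterministic=True)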
