[JAX] THD ring attention #1454

Merged: 12 commits on Mar 3, 2025
60 changes: 37 additions & 23 deletions tests/jax/test_distributed_fused_attn.py
@@ -23,6 +23,7 @@
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
CPStrategy,
ReorderStrategy,
)


@@ -210,29 +211,29 @@ def test_cross_attn(
"data_shape",
[
# Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
pytest.param([2, 128, 12, 128], id="2-128xCP-12-128"),
pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
],
)
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
@pytest.mark.parametrize("kv_groups", [1, 8])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"attn_mask_type",
"qkv_layout, attn_mask_type",
[
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
],
)
@pytest.mark.parametrize("dtype", [jnp.bfloat16])
@pytest.mark.parametrize(
"qkv_layout",
[
pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"),
pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="BSHD_KVPACKED-NO_MASK"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"),
pytest.param(
QKVLayout.THD_THD_THD,
AttnMaskType.PADDING_CAUSAL_MASK,
id="THD_SEPARATE-PADDING_CAUSAL",
),
],
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
[pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
)
class TestDistributedContextParallelSelfAttn:

@@ -265,7 +266,6 @@ def impl_test_context_parallel_attn(
data_shape = batch, seqlen, num_head, hidden

num_kv_heads = num_head // kv_groups
scaling_factor = 1.0 / np.sqrt(num_head)

runner = FusedAttnRunner(
batch,
@@ -282,7 +282,7 @@
qkv_layout,
bias_shape,
None,
SeqDescFormat.Seqlens,
SeqDescFormat.SegmentIDs,
number_of_devices=device_count,
mesh_shape=mesh_shape,
mesh_axes=mesh_axes,
@@ -297,7 +297,7 @@ def check_has_backend_for_mask(mask_type):
dtype,
qkv_layout,
attn_bias_type,
attn_mask_type,
mask_type,
dropout_prob,
num_head,
num_kv_heads,
@@ -340,6 +340,8 @@ def test_context_parallel_allgather_attn(
qkv_layout,
load_balanced,
):
if qkv_layout.is_thd():
pytest.skip("THD doesn't support all gather context parallelism.")
return self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
@@ -377,7 +379,10 @@ def test_context_parallel_ring_attn(
else:
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0"

self.impl_test_context_parallel_attn(
if qkv_layout.is_thd() and not load_balanced:
pytest.skip("THD + ring doesn't support unbalanced context parallelism.")

return self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
@@ -404,17 +409,26 @@ class TestReorderCausalLoadBalancing:
],
)
@pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
def test(self, cp_size, shape, qkv_format):
@pytest.mark.parametrize(
"reorder_strategy",
[
pytest.param(ReorderStrategy.DualChunkSwap, id="DualChunkSwap"),
pytest.param(ReorderStrategy.Striped, id="Striped"),
],
)
def test(self, cp_size, shape, qkv_format, reorder_strategy):
tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16)
seq_dim = 1
if qkv_format == QKVFormat.SBHD:
tensor = tensor.swapaxes(0, 1)
seq_dim = 0

ref = tensor.copy()

reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2])
inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2])
reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3])
inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3])

reordered = reorder(tensor, cp_size, qkv_format)
inversed = inverse(reordered, cp_size, qkv_format)
reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim)
inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim)

assert jnp.array_equal(inversed, ref)
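
For reference, the reorder/inverse round trip exercised by this test can be reproduced standalone. This is a minimal sketch under the assumption that the helpers are exported from transformer_engine.jax.attention (the import statement itself sits above the hunks shown in this diff); it mirrors the new four-argument signature (tensor, strategy, cp_size, seq_dim) used above.

import jax
import jax.numpy as jnp
from jax import random
# Assumed import path; the `from ... import (` line is outside the hunks shown here.
from transformer_engine.jax.attention import (
    ReorderStrategy,
    reorder_causal_load_balancing,
    inverse_reorder_causal_load_balancing,
)

cp_size = 4  # number of context-parallel ranks
seq_dim = 1  # BSHD: the sequence axis is dim 1
x = random.normal(random.PRNGKey(0), (2, 32, 8, 64), dtype=jnp.bfloat16)

# Striped reorder (the strategy used for THD ring attention) and its inverse.
reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3])
inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3])

y = reorder(x, ReorderStrategy.Striped, cp_size, seq_dim)
x_back = inverse(y, ReorderStrategy.Striped, cp_size, seq_dim)
assert jnp.array_equal(x_back, x)  # round trip recovers the original token order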
89 changes: 59 additions & 30 deletions tests/jax/test_fused_attn.py
@@ -28,12 +28,14 @@
AttnBiasType,
AttnMaskType,
QKVLayout,
QKVFormat,
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
fused_attn,
make_swa_mask,
SequenceDescriptor,
CPStrategy,
ReorderStrategy,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine.transformer_engine_jax import (
@@ -347,9 +349,9 @@ def _check_configs(self):
self.backend = FusedAttnHelper(
self.dtype,
self.dtype,
self.qkv_layout.value,
self.attn_bias_type.value,
self.attn_mask_type.value,
self.qkv_layout,
self.attn_bias_type,
self.attn_mask_type,
self.dropout_prob,
self.num_heads_q,
self.num_heads_kv,
@@ -500,7 +502,8 @@ def generate_random_segment_ids(
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
)
self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
if self.qkv_layout == QKVLayout.T3HD:
# TODO(rewang): record only self attention and find the reason of cross attention
if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv:
self.segment_ids_kv = self.segment_ids_q
self.segment_pos_kv = self.segment_pos_q
self.pad_kv = self.pad_q
@@ -536,6 +539,30 @@ def generate_random_segment_ids(
self.window_size,
)

if self.cp_size > 1 and self.cp_load_balanced:
if self.qkv_layout.is_thd():
reorder_strategy = ReorderStrategy.Striped
else:
reorder_strategy = ReorderStrategy.DualChunkSwap

seq_dim = 0 if self.qkv_layout.get_qkv_format() == QKVFormat.SBHD else 1
self.cp_reorder_fn = partial(
reorder_causal_load_balancing,
strategy=reorder_strategy,
cp_size=self.cp_size,
seq_dim=seq_dim,
)
self.cp_inverse_reorder_fn = partial(
inverse_reorder_causal_load_balancing,
strategy=reorder_strategy,
cp_size=self.cp_size,
seq_dim=seq_dim,
)
else:
# no-ops for non cp or non load balanced
self.cp_reorder_fn = lambda x: x
self.cp_inverse_reorder_fn = lambda x: x

# Test different input formats
if self.qkv_layout.is_thd():
match self.seq_desc_format:
@@ -548,8 +575,14 @@
)
case SeqDescFormat.SegmentIDs:
self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
(self.segment_ids_q, self.segment_ids_kv),
(self.segment_pos_q, self.segment_pos_kv),
(
self.cp_reorder_fn(self.segment_ids_q),
self.cp_reorder_fn(self.segment_ids_kv),
),
(
self.cp_reorder_fn(self.segment_pos_q),
self.cp_reorder_fn(self.segment_pos_kv),
),
)
case _:
raise ValueError(f"Unknown {self.seq_desc_format=}")
@@ -605,7 +638,12 @@ def generate_random_segment_ids(
case _:

def to_dp_shardings(x):
pspec = PartitionSpec(self.mesh_resource.dp_resource)
if x.ndim == 1:
pspec = PartitionSpec(self.mesh_resource.dp_resource)
else:
pspec = PartitionSpec(
self.mesh_resource.dp_resource, self.mesh_resource.cp_resource
)
return NamedSharding(self.mesh, pspec)

self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_desciptor)
@@ -637,24 +675,6 @@ def to_dp_shardings(x):
self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None)
self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec)

# Softmax aux sharding

if self.cp_size > 1 and self.cp_load_balanced:
self.cp_reorder_fn = partial(
reorder_causal_load_balancing,
cp_size=self.cp_size,
tensor_format=self.qkv_layout.get_qkv_format(),
)
self.cp_inverse_reorder_fn = partial(
inverse_reorder_causal_load_balancing,
cp_size=self.cp_size,
tensor_format=self.qkv_layout.get_qkv_format(),
)
else:
# no-ops for non cp or non load balanced
self.cp_reorder_fn = lambda x: x
self.cp_inverse_reorder_fn = lambda x: x

def test_forward(self):
"""
Test forward without JIT
@@ -733,15 +753,24 @@ def test_backward(self):

self._setup_inputs()

def grad_func(func, *args, **kwargs):
def grad_func(func, *args, cp_reverse_out=False, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the gradient
gradient_multiplier = self.max_seqlen_q * self.num_heads_q
if self.attn_mask_type.is_causal():
gradient_multiplier /= 10
# Keep only valid result for the gradient
ret_valid = jnp.where(
self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, func(*args, **kwargs)
)
if not cp_reverse_out:
ret_valid = jnp.where(
self.pad_q[..., jnp.newaxis, jnp.newaxis],
0,
func(*args, **kwargs),
)
else:
ret_valid = jnp.where(
self.pad_q[..., jnp.newaxis, jnp.newaxis],
0,
self.cp_inverse_reorder_fn(func(*args, **kwargs)),
)
return (
jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
).astype(self.dtype)
@@ -787,7 +816,7 @@ def grad_func(func, *args, **kwargs):
jitted_primitive = jit(
value_and_grad(
lambda q, k, v, bias, *args: grad_func(
customcall_fused_dpa, q, k, v, bias, *args, **kwargs
customcall_fused_dpa, q, k, v, bias, *args, cp_reverse_out=True, **kwargs
),
arg_nums,
),
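
The backward test above threads the new cp_reverse_out flag through grad_func so that, with load-balanced context parallelism, the output is mapped back to the original token order before the padding mask and reference comparison are applied, matching the up-front reordering of the segment IDs/positions fed to SequenceDescriptor.from_segment_ids_and_pos. A minimal sketch of that overall pattern follows; run_load_balanced_cp and fused_attn_fn are hypothetical stand-ins (the latter playing the role of the runner's customcall_fused_dpa), and seq_dim=1 follows the runner's convention for its padded THD tensors.

from functools import partial
from transformer_engine.jax.attention import (  # assumed module path
    ReorderStrategy,
    reorder_causal_load_balancing,
    inverse_reorder_causal_load_balancing,
)

def run_load_balanced_cp(fused_attn_fn, q, k, v, segment_ids, segment_pos, cp_size):
    # THD layouts use the striped reorder; BSHD/SBHD use DualChunkSwap in the runner.
    strategy = ReorderStrategy.Striped
    reorder = partial(reorder_causal_load_balancing,
                      strategy=strategy, cp_size=cp_size, seq_dim=1)
    unreorder = partial(inverse_reorder_causal_load_balancing,
                        strategy=strategy, cp_size=cp_size, seq_dim=1)

    # Inputs (tensors and their segment metadata) are reordered once, up front ...
    out = fused_attn_fn(reorder(q), reorder(k), reorder(v),
                        reorder(segment_ids), reorder(segment_pos))
    # ... and the output is mapped back before masking and comparison with the reference.
    return unreorder(out)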