[JAX] Support segment_ids/pos as FA inputs (#1406)
* POC for segment_ids/segment_pos

Signed-off-by: Reese Wang <[email protected]>

* Change segment_pos position

Signed-off-by: Reese Wang <[email protected]>

* Use RemainingArgs to resolve parameter-count mismatches

Signed-off-by: Reese Wang <[email protected]>

* Test mask_descriptor for accommodating different mask representations

Signed-off-by: Reese Wang <[email protected]>

* Fix bugs

Signed-off-by: Reese Wang <[email protected]>

* Use descriptor in bwd

Signed-off-by: Reese Wang <[email protected]>

* Primitives only accept pure jnp arrays

Signed-off-by: Reese Wang <[email protected]>

* segment_ids/pos support POC

Signed-off-by: Reese Wang <[email protected]>

* Move seqlens/offsets generation to mask descriptor

Signed-off-by: Reese Wang <[email protected]>

* Rename MaskDescriptor to SequenceDescriptor

Signed-off-by: Reese Wang <[email protected]>

* Generalize get_seqlens_and_offsets

Signed-off-by: Reese Wang <[email protected]>

* Utilize sequence desc on FA bwd

Signed-off-by: Reese Wang <[email protected]>

* Migrate to new API

Signed-off-by: Reese Wang <[email protected]>

* Add docstrings

Signed-off-by: Reese Wang <[email protected]>

* Remove small inputs and test different input format

Signed-off-by: Reese Wang <[email protected]>

* Fix lint

Signed-off-by: Reese Wang <[email protected]>

* Fix seed shardings

Signed-off-by: Reese Wang <[email protected]>

* Optimize sequence converting overhead

Signed-off-by: Reese Wang <[email protected]>

* Optimize seq_offsets calculation

Signed-off-by: Reese Wang <[email protected]>

* Fix up

Signed-off-by: Reese Wang <[email protected]>

* Fix lint

Signed-off-by: Reese Wang <[email protected]>

* Fix conflicts

Signed-off-by: Reese Wang <[email protected]>

* Remove redundant line

Signed-off-by: Reese Wang <[email protected]>

---------

Signed-off-by: Reese Wang <[email protected]>
zlsh80826 authored Jan 24, 2025
1 parent 3d7ff1c commit c2c3d54
Showing 5 changed files with 786 additions and 297 deletions.
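
In short: fused_attn now takes a single SequenceDescriptor in place of the separate mask/seqlens/offsets arguments, and the fused_attn_thd path is folded into the same entry point. The sketch below is inferred from the updated tests in this diff, not taken from the commit itself; shapes and keyword arguments are illustrative assumptions rather than the definitive API.

import jax.numpy as jnp
from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
    SequenceDescriptor,
    fused_attn,
)

# Illustrative self-attention shapes: batch=2, seqlen=128, heads=16, head_dim=64.
q = jnp.zeros((2, 128, 16, 64), dtype=jnp.bfloat16)
k = jnp.zeros((2, 128, 16, 64), dtype=jnp.bfloat16)
v = jnp.zeros((2, 128, 16, 64), dtype=jnp.bfloat16)

# Three ways to describe sequences, all exercised by the tests below:
# 1) per-batch sequence lengths (second sequence padded to length 100),
seq_desc = SequenceDescriptor.from_seqlens(
    (jnp.array([128, 100], jnp.int32), jnp.array([128, 100], jnp.int32))
)
# 2) lengths plus ragged offsets, the THD path formerly served by fused_attn_thd:
#    SequenceDescriptor.from_seqlens_and_offsets((seqlens_q, seqlens_kv),
#                                                (offsets_q, offsets_kv))
# 3) segment IDs with optional segment positions:
#    SequenceDescriptor.from_segment_ids_and_pos((segment_ids_q, segment_ids_kv), None)

out = fused_attn(
    (q, k, v),            # qkv_args; must match qkv_layout below
    None,                 # bias
    seq_desc,             # SequenceDescriptor (a legacy mask array is also accepted)
    None,                 # dropout rng key (unused when dropout_probability == 0)
    attn_bias_type=AttnBiasType.NO_BIAS,
    attn_mask_type=AttnMaskType.PADDING_MASK,
    qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
    scaling_factor=1.0 / 8.0,   # 1/sqrt(head_dim)
    dropout_probability=0.0,
    is_training=False,
)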
21 changes: 6 additions & 15 deletions tests/jax/test_distributed_fused_attn.py
@@ -2,31 +2,18 @@
#
# See LICENSE for license information.

import pytest
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from flax.linen import dot_product_attention
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import (
generate_configs,
generate_context_parallel_configs,
generate_collectives_count,
compare_ops,
)
from utils import (
make_causal_mask,
make_self_mask,
assert_allclose,
print_debug_tensor_stats,
)
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.attention import (
is_fused_attn_kernel_available,
fused_attn,
AttnBiasType,
AttnMaskType,
QKVLayout,
@@ -36,10 +23,11 @@
CPStrategy,
)
from transformer_engine.jax.sharding import MeshResource
import pytest

from test_fused_attn import FusedAttnRunner, BiasShape, general_dot_product_attention, make_mask
from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat

DTYPES = [jnp.float16, jnp.bfloat16]
DTYPES = [jnp.bfloat16]


class TestDistributedSelfAttn:
@@ -141,6 +129,7 @@ def test_self_attn(
QKVLayout.BS3HD,
bias_shape,
None,
SeqDescFormat.Seqlens,
number_of_devices=device_count,
mesh_shape=mesh_shape,
mesh_axes=mesh_axes,
@@ -205,6 +194,7 @@ def test_cross_attn(
QKVLayout.BSHD_BS2HD,
bias_shape,
None,
SeqDescFormat.Seqlens,
number_of_devices=device_count,
mesh_shape=mesh_shape,
mesh_axes=mesh_axes,
@@ -293,6 +283,7 @@ def impl_test_context_parallel_attn(
qkv_layout,
bias_shape,
None,
SeqDescFormat.Seqlens,
number_of_devices=device_count,
mesh_shape=mesh_shape,
mesh_axes=mesh_axes,
160 changes: 95 additions & 65 deletions tests/jax/test_fused_attn.py
@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
"""Tests for fused attention"""
from enum import Enum
from enum import Enum, auto
from dataclasses import dataclass, field
from functools import partial
from math import sqrt
@@ -28,12 +28,11 @@
AttnBiasType,
AttnMaskType,
QKVLayout,
QKVFormat,
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
fused_attn,
fused_attn_thd,
make_swa_mask,
SequenceDescriptor,
CPStrategy,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
@@ -199,8 +198,8 @@ def _find_offsets(x):
).squeeze(-1)

offsets = _find_offsets(segment_ids)
offsets = jnp.insert(offsets, -1, values=-1, axis=-1)
seqlens = jnp.insert(seqlens, -1, values=0, axis=-1)
offsets = jnp.insert(offsets, offsets.shape[-1], values=-1, axis=-1)
seqlens = jnp.insert(seqlens, seqlens.shape[-1], values=0, axis=-1)
seqlens = jnp.where(seqlens, seqlens, -1)
return seqlens, offsets
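
The two-line change above moves where the sentinels land: jnp.insert with index -1 places values before the last element, while index x.shape[-1] appends them, which is what the padding row needs. A standalone illustration (assumed example, not part of the commit):

import jax.numpy as jnp

x = jnp.array([3, 5, 7], dtype=jnp.int32)
jnp.insert(x, -1, -1)           # [3, 5, -1, 7] -- inserted before the last element
jnp.insert(x, x.shape[-1], -1)  # [3, 5, 7, -1] -- appended at the end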

Expand Down Expand Up @@ -239,11 +238,7 @@ def customcall_fused_dpa(
key,
value,
bias,
mask,
seqlens_q,
seqlens_kv,
offsets_q,
offsets_kv,
sequence_descriptor,
dropout_rng,
**kwargs,
):
@@ -264,19 +259,9 @@
qkv_args = (query, key, value)
case _:
raise ValueError(f"Unsupported {qkv_layout=}")
if not qkv_layout.is_thd():
kwargs.pop("max_segments_per_seq")
return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
return fused_attn_thd(
qkv_args,
bias,
seqlens_q,
seqlens_kv,
offsets_q,
offsets_kv,
dropout_rng,
**kwargs,
).astype(query.dtype)
return fused_attn(qkv_args, bias, sequence_descriptor, dropout_rng, **kwargs).astype(
query.dtype
)


class BiasShape(Enum):
Expand All @@ -290,6 +275,12 @@ class BiasShape(Enum):
_11SS = "11SS"


class SeqDescFormat(Enum):
Mask = auto()
Seqlens = auto()
SegmentIDs = auto()


@dataclass
class FusedAttnRunner:
"""
Expand All @@ -309,7 +300,8 @@ class FusedAttnRunner:
is_training: bool
qkv_layout: QKVLayout
bias_shape: BiasShape
window_size: Optional[Tuple[int, int]] = None
window_size: Tuple[int, int]
seq_desc_format: SeqDescFormat

# Specifies sharding resources for distributed tests
number_of_devices: int = 1
@@ -327,11 +319,14 @@ class FusedAttnRunner:
# See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue
# generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
def _get_max_segments_per_sequence(self):
if 90400 <= get_cudnn_version() < 90500:
return self.num_segments_per_seq
if self.qkv_layout.is_thd():
if 90400 <= get_cudnn_version() < 90500:
return self.num_segments_per_seq
else:
# +1 for testing runtime_segments < max_segments
return self.num_segments_per_seq + 1
else:
# +1 for testing runtime_segments < max_segments
return self.num_segments_per_seq + 1
return 1

def _check_configs(self):
# TODO(rewang): probably adds this in is_fused_attn_available
@@ -462,11 +457,11 @@ def generate_random_segment_ids(
):
rng = np.random.default_rng(seed=seed)
# [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
segment_ids = np.zeros((batch_size, sequence_length), dtype=int)
segment_pos = np.zeros((batch_size, sequence_length), dtype=int)
segment_ids = np.zeros((batch_size, sequence_length), dtype=np.int32)
segment_pos = np.zeros((batch_size, sequence_length), dtype=np.int32)
# [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0]
# [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
segment_pad = np.zeros((batch_size, sequence_length), dtype=int)
segment_pad = np.zeros((batch_size, sequence_length), dtype=np.int32)

# Not include paddings
max_segment_size = sequence_length // num_segments
@@ -541,16 +536,47 @@
self.window_size,
)

# Test different input formats
if self.qkv_layout.is_thd():
self.mask_for_customcall = None # THD format doesn't support mask
match self.seq_desc_format:
case SeqDescFormat.Mask:
pytest.skip("THD doesn't support mask input")
case SeqDescFormat.Seqlens:
self.sequence_desciptor = SequenceDescriptor.from_seqlens_and_offsets(
(self.seqlens_q, self.seqlens_kv),
(self.offsets_q, self.offsets_kv),
)
case SeqDescFormat.SegmentIDs:
self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
(self.segment_ids_q, self.segment_ids_kv),
(self.segment_pos_q, self.segment_pos_kv),
)
case _:
raise ValueError(f"Unknown {self.seq_desc_format=}")
else:
self.mask_for_customcall = make_mask(
self.segment_ids_q,
self.segment_ids_kv,
self.segment_pos_q,
self.segment_pos_kv,
self.attn_mask_type,
)
match self.seq_desc_format:
case SeqDescFormat.Mask:
self.sequence_desciptor = make_mask(
self.segment_ids_q,
self.segment_ids_kv,
self.segment_pos_q,
self.segment_pos_kv,
self.attn_mask_type,
)
case SeqDescFormat.Seqlens:
self.sequence_desciptor = SequenceDescriptor.from_seqlens(
(
self.segment_ids_q.sum(axis=-1).astype(jnp.int32),
self.segment_ids_kv.sum(axis=-1).astype(jnp.int32),
),
)
case SeqDescFormat.SegmentIDs:
self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
(self.segment_ids_q, self.segment_ids_kv),
None,
)
case _:
raise ValueError(f"Unknown {self.seq_desc_format=}")

self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
self.scaling_factor = 1.0 / sqrt(self.head_dim)
@@ -565,10 +591,21 @@ def generate_random_segment_ids(
)
self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec)

self.mask_pspec = PartitionSpec(
mask_pspec = PartitionSpec(
self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
)
self.mask_sharding = NamedSharding(self.mesh, self.mask_pspec)
self.mask_sharding = NamedSharding(self.mesh, mask_pspec)

match self.seq_desc_format:
case SeqDescFormat.Mask:
self.seq_desc_sharding = self.mask_sharding
case _:

def to_dp_shardings(x):
pspec = PartitionSpec(self.mesh_resource.dp_resource)
return NamedSharding(self.mesh, pspec)

self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_desciptor)

if self.bias_shape == BiasShape._1HSS:
self.bias_pspec = PartitionSpec(
Expand Down Expand Up @@ -631,11 +668,7 @@ def test_forward(self):
jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
jax.device_put(self.bias, self.bias_sharding),
jax.device_put(self.mask_for_customcall, self.mask_sharding),
jax.device_put(self.seqlens_q, self.seq_length_offset_sharding),
jax.device_put(self.seqlens_kv, self.seq_length_offset_sharding),
jax.device_put(self.offsets_q, self.seq_length_offset_sharding),
jax.device_put(self.offsets_kv, self.seq_length_offset_sharding),
jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
]
kwargs = {
@@ -659,11 +692,7 @@
self.qkvo_sharding,
self.qkvo_sharding,
self.bias_sharding,
self.mask_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_desc_sharding,
self.dropout_rng_sharding,
],
)
@@ -722,11 +751,7 @@ def grad_func(func, *args, **kwargs):
jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
jax.device_put(self.bias, self.bias_sharding),
jax.device_put(self.mask_for_customcall, self.mask_sharding),
jax.device_put(self.seqlens_q, self.seq_length_offset_sharding),
jax.device_put(self.seqlens_kv, self.seq_length_offset_sharding),
jax.device_put(self.offsets_q, self.seq_length_offset_sharding),
jax.device_put(self.offsets_kv, self.seq_length_offset_sharding),
jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
]
kwargs = {
@@ -768,11 +793,7 @@ def grad_func(func, *args, **kwargs):
self.qkvo_sharding,
self.qkvo_sharding,
self.bias_sharding,
self.mask_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_desc_sharding,
self.dropout_rng_sharding,
),
out_shardings=(None, grad_shardings),
@@ -883,10 +904,7 @@ def check_dqkv(primitive, reference, pad, idx):
@pytest.mark.parametrize(
"b, s_q, s_kv, h_q, h_kv, d, dtype",
[
pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"),
pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
pytest.param(4, 128, 256, 16, 16, 64, jnp.bfloat16, id="4-128-256-16-16-64-BF16-CROSS"),
pytest.param(
2,
2048,
@@ -897,8 +915,8 @@ def check_dqkv(primitive, reference, pad, idx):
jnp.bfloat16,
id="2-2048-1024-12-12-64-BF16-CROSS",
),
pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"),
pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
],
)
@pytest.mark.parametrize(
@@ -915,6 +933,14 @@ def check_dqkv(primitive, reference, pad, idx):
pytest.param(True, id="SWA"),
],
)
@pytest.mark.parametrize(
"seq_desc_format",
[
pytest.param(SeqDescFormat.Mask, id="Mask"),
pytest.param(SeqDescFormat.Seqlens, id="Seqlens"),
pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"),
],
)
class TestFusedAttn:
"""
Fused attention tester
@@ -953,6 +979,7 @@ def _test_forward(
qkv_layout,
bias_shape,
swa,
seq_desc_format,
):
"""
Test forward with parameterized configs
@@ -977,6 +1004,7 @@
qkv_layout,
bias_shape,
window_size,
seq_desc_format,
)
runner.test_forward()

@@ -1002,6 +1030,7 @@ def test_backward(
qkv_layout,
bias_shape,
swa,
seq_desc_format,
):
"""
Test backward with parameterized configs
@@ -1024,5 +1053,6 @@
qkv_layout,
bias_shape,
window_size,
seq_desc_format,
)
runner.test_backward()
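
Since SeqDescFormat is a new parametrize axis, every forward/backward configuration now runs once per descriptor format; a single format can be selected through its test id, for example:

pytest tests/jax/test_fused_attn.py -k SegmentIDs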