
Commit

Align different reorder code structure
Signed-off-by: Reese Wang <[email protected]>
zlsh80826 committed Feb 24, 2025
1 parent fc2ebcb commit 2da1075
Showing 2 changed files with 35 additions and 46 deletions.
46 changes: 4 additions & 42 deletions transformer_engine/jax/attention.py
@@ -334,9 +334,9 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
def reorder_causal_load_balancing(tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int):
    """Reorders a tensor for load balancing the compute of causal attention."""
    if strategy == ReorderStrategy.DualChunkSwap:
-        return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, False)
+        return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, False)
    if strategy == ReorderStrategy.Striped:
-        return _reorder_causal_striped(tensor, cp_size, seq_dim)
+        return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, False)
    raise ValueError(f"Unsupported {strategy=}")


@@ -345,50 +345,12 @@ def inverse_reorder_causal_load_balancing(
):
    """Inverse operation of `reorder_causal_load_balancing`."""
    if strategy == ReorderStrategy.DualChunkSwap:
-        return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, True)
+        return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, True)
    if strategy == ReorderStrategy.Striped:
-        return _inverse_reorder_causal_striped(tensor, cp_size, seq_dim)
+        return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, True)
    raise ValueError(f"Unsupported {strategy=}")


-def _reorder_causal_striped(tensor, cp_size: int, seq_dim: int):
-    origin_shape = tensor.shape
-    if origin_shape[seq_dim] % cp_size != 0:
-        raise ValueError(
-            "Expected origin_shape[seq_dim] is multiple of cp_size but got"
-            f" {origin_shape[seq_dim]=} and {cp_size=}"
-        )
-
-    new_shape = [
-        *origin_shape[:seq_dim],
-        *[origin_shape[seq_dim] // cp_size, cp_size],
-        *origin_shape[seq_dim + 1 :],
-    ]
-
-    chunked_tensor = tensor.reshape(new_shape)
-    reordered_chunked_tensor = jnp.swapaxes(chunked_tensor, seq_dim, seq_dim + 1)
-    return reordered_chunked_tensor.reshape(origin_shape)
-
-
-def _inverse_reorder_causal_striped(tensor, cp_size: int, seq_dim: int):
-    origin_shape = tensor.shape
-    if origin_shape[seq_dim] % cp_size != 0:
-        raise ValueError(
-            "Expected origin_shape[seq_dim] is multiple of cp_size but got"
-            f" {origin_shape[seq_dim]=} and {cp_size=}"
-        )
-
-    new_shape = [
-        *origin_shape[:seq_dim],
-        *[cp_size, origin_shape[seq_dim] // cp_size],
-        *origin_shape[seq_dim + 1 :],
-    ]
-
-    chunked_tensor = tensor.reshape(new_shape)
-    reordered_chunked_tensor = jnp.swapaxes(chunked_tensor, seq_dim, seq_dim + 1)
-    return reordered_chunked_tensor.reshape(origin_shape)
-
-
def _get_seqlens_and_offsets(segment_ids, max_segments_per_seq):
    # bincount map with 0s
    bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_segments_per_seq + 1))
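For context, a minimal usage sketch of the public dispatchers kept above. This assumes ReorderStrategy, reorder_causal_load_balancing, and inverse_reorder_causal_load_balancing are importable from transformer_engine.jax.attention; the toy shapes are illustrative only.

# Minimal sketch (assumption: these names are importable from transformer_engine.jax.attention).
# Reorder a toy (batch, seqlen) tensor for striped context-parallel load balancing and undo it.
import jax.numpy as jnp
from transformer_engine.jax.attention import (
    ReorderStrategy,
    inverse_reorder_causal_load_balancing,
    reorder_causal_load_balancing,
)

cp_size, seq_dim = 4, 1
x = jnp.arange(2 * 8).reshape(2, 8)  # seqlen must be divisible by cp_size

y = reorder_causal_load_balancing(x, ReorderStrategy.Striped, cp_size, seq_dim)
x_roundtrip = inverse_reorder_causal_load_balancing(y, ReorderStrategy.Striped, cp_size, seq_dim)
assert (x == x_roundtrip).all()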
35 changes: 31 additions & 4 deletions transformer_engine/jax/cpp_extensions/attention.py
@@ -534,7 +534,7 @@ def convert_to_2d(offsets, batch, max_seqlen):
batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval(
    q, k, v, config.qkv_layout
)
-assert len(batch) == 1
+assert len(batch) == 1, f"Expected len(batch) == 1, but got {len(batch)=}"
kv_batch = q_batch = batch[0]

# Gather valid q_seqlen, which is greater than 0
@@ -1081,7 +1081,7 @@ def sharded_impl(
register_primitive(FusedAttnBwdPrimitive)


-def reorder_causal_load_balancing(tensor, cp_size: int, seq_dim: int, to_contiguous: bool):
+def reorder_causal_dual_chunk_swap(tensor, cp_size: int, seq_dim: int, to_contiguous: bool):
    """Reorders a tensor for load balancing the compute of causal attention."""
    if cp_size == 1:
        return tensor
@@ -1133,6 +1133,33 @@ def reorder_causal_load_balancing(tensor, cp_size: int, seq_dim: int, to_contigu
    return combined.reshape(ori_tensor_shape)


+def reorder_causal_striped(tensor, cp_size: int, seq_dim: int, is_inverse: bool):
+    """Reorders a tensor for load balancing with striped pattern"""
+    origin_shape = tensor.shape
+    if origin_shape[seq_dim] % cp_size != 0:
+        raise ValueError(
+            "Expected origin_shape[seq_dim] is multiple of cp_size but got"
+            f" {origin_shape[seq_dim]=} and {cp_size=}"
+        )
+
+    if not is_inverse:
+        new_shape = [
+            *origin_shape[:seq_dim],
+            *[origin_shape[seq_dim] // cp_size, cp_size],
+            *origin_shape[seq_dim + 1 :],
+        ]
+    else:
+        new_shape = [
+            *origin_shape[:seq_dim],
+            *[cp_size, origin_shape[seq_dim] // cp_size],
+            *origin_shape[seq_dim + 1 :],
+        ]
+
+    chunked_tensor = tensor.reshape(new_shape)
+    reordered_chunked_tensor = jnp.swapaxes(chunked_tensor, seq_dim, seq_dim + 1)
+    return reordered_chunked_tensor.reshape(origin_shape)
+
+
@dataclass(frozen=True)
class _FusedAttnCPWithAllGatherHelper:
    """Helper class to assist with running the all-gather strategy for CP attention."""
@@ -1200,7 +1227,7 @@ def ag(x):
)
if self.config.context_parallel_load_balanced:
    cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
-    x = reorder_causal_load_balancing(x, cp_size, 1, to_contiguous=True)
+    x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=True)
return x

if self.config.qkv_layout.is_kvpacked():
@@ -1216,7 +1243,7 @@ def reduce_scatter_dkv(self, dk, dv):
def rs(x):
    if self.config.context_parallel_load_balanced:
        cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
-        x = reorder_causal_load_balancing(x, cp_size, 1, to_contiguous=False)
+        x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=False)

    return lax_paral_op(
        x,
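For background on the two strategies: DualChunkSwap splits the sequence into 2 * cp_size chunks and pairs an early chunk with a late one on each rank, so every rank sees a balanced mix of short and long causal spans, while Striped interleaves individual tokens as shown earlier. Below is a rough sketch of the conventional dual-chunk-swap pairing; the pairing rule is an assumption for illustration, not the exact layout produced by reorder_causal_dual_chunk_swap.

# Conventional dual-chunk-swap pairing for causal load balancing (illustrative assumption):
# rank r receives chunks r and 2*cp_size-1-r of the sequence.
import jax.numpy as jnp

cp_size, seqlen = 2, 8
chunks = jnp.arange(seqlen).reshape(2 * cp_size, -1)  # 4 chunks of 2 tokens
per_rank = [
    jnp.concatenate([chunks[r], chunks[2 * cp_size - 1 - r]])  # pair early with late
    for r in range(cp_size)
]
# per_rank[0] == [0, 1, 6, 7], per_rank[1] == [2, 3, 4, 5]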
