Add missing arguments of apply_rotary_pos_emb in MHA (#1296)
* add missing arguments of apply_rotary_pos_emb in MHA

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove an unnecessary f

Signed-off-by: Xiaowei Ren <[email protected]>

* add one more assert for cp_group len

Signed-off-by: Xiaowei Ren <[email protected]>

---------

Signed-off-by: Xiaowei Ren <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
xrennvidia and pre-commit-ci[bot] authored Oct 30, 2024

1 parent 8bdb54f commit ed1e85c
1 changed file: transformer_engine/pytorch/attention.py (35 additions, 2 deletions)
@@ -8495,6 +8495,8 @@ def __init__(
         self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.num_attention_heads = num_attention_heads
         self.return_bias = return_bias
+        self.cp_size = 1
+        self.cp_rank = 0

         kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)

@@ -8713,6 +8715,21 @@ def set_context_parallel_group(
         across each CP sub-group (e.g., via NVLink), then exchanging KV with
         p2p between sub-groups (e.g., via IBLink).
         """
+        if isinstance(cp_group, dist_group_type):
+            self.cp_size = get_distributed_world_size(cp_group)
+            self.cp_rank = get_distributed_rank(cp_group)
+        elif isinstance(cp_group, list):
+            assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!"
+            assert (
+                cp_comm_type == "a2a+p2p"
+            ), "Only cp_comm_type of a2a+p2p requires hierarchical CP groups!"
+            cp_size_a2a = get_distributed_world_size(cp_group[0])
+            cp_rank_a2a = get_distributed_rank(cp_group[0])
+            cp_size_p2p = get_distributed_world_size(cp_group[1])
+            cp_rank_p2p = get_distributed_rank(cp_group[1])
+            self.cp_size = cp_size_a2a * cp_size_p2p
+            self.cp_rank = cp_size_a2a * cp_rank_p2p + cp_rank_a2a
+
         # Deep iterate but skip self to avoid infinite recursion.
         for index, child in enumerate(self.modules()):
             if index == 0:
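
Note: the rank flattening in the hierarchical branch added above can be sanity-checked in isolation. The following is a minimal, illustrative sketch in plain Python (no Transformer Engine imports); the sub-group sizes are hypothetical and chosen only to show that cp_size_a2a * cp_rank_p2p + cp_rank_a2a enumerates every global CP rank exactly once.

# Illustrative only: hypothetical two-level CP layout with 2 ranks per a2a sub-group
# (e.g., NVLink-connected) and 4 p2p sub-groups (e.g., IBLink-connected).
cp_size_a2a, cp_size_p2p = 2, 4

flat_ranks = [
    cp_size_a2a * cp_rank_p2p + cp_rank_a2a      # same formula as the added code above
    for cp_rank_p2p in range(cp_size_p2p)        # outer level: p2p between sub-groups
    for cp_rank_a2a in range(cp_size_a2a)        # inner level: a2a within a sub-group
]

cp_size = cp_size_a2a * cp_size_p2p
assert flat_ranks == list(range(cp_size))        # covers ranks 0..7 with no gaps or repeats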
@@ -9047,8 +9064,24 @@ def forward(
                 q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
                 k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]

-            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
-            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)
+            query_layer = apply_rotary_pos_emb(
+                query_layer,
+                q_pos_emb,
+                self.qkv_format,
+                fused=True,
+                cu_seqlens=cu_seqlens_q,
+                cp_size=self.cp_size,
+                cp_rank=self.cp_rank,
+            )
+            key_layer = apply_rotary_pos_emb(
+                key_layer,
+                k_pos_emb,
+                self.qkv_format,
+                fused=True,
+                cu_seqlens=cu_seqlens_kv,
+                cp_size=self.cp_size,
+                cp_rank=self.cp_rank,
+            )

         # ===========================
         # Core attention computation
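
Note: the extra cu_seqlens, cp_size, and cp_rank arguments let the fused RoPE path recover the absolute positions of the locally held tokens, since under context parallelism each CP rank only holds a slice of every sequence. The sketch below is illustrative only, not Transformer Engine code: it assumes the load-balanced sharding used for context parallelism (rank r owning sequence chunks r and 2*cp_size - 1 - r); the real indexing lives inside apply_rotary_pos_emb.

import torch

# Illustrative only: which absolute positions a CP rank's local tokens map to,
# assuming the load-balanced sharding where rank r owns chunks r and 2*cp_size - 1 - r.
def local_position_ids(seqlen: int, cp_size: int, cp_rank: int) -> torch.Tensor:
    chunk = seqlen // (2 * cp_size)
    first = torch.arange(cp_rank * chunk, (cp_rank + 1) * chunk)
    second_start = (2 * cp_size - 1 - cp_rank) * chunk
    second = torch.arange(second_start, second_start + chunk)
    return torch.cat([first, second])

print(local_position_ids(seqlen=16, cp_size=2, cp_rank=0))  # tensor([ 0,  1,  2,  3, 12, 13, 14, 15])
print(local_position_ids(seqlen=16, cp_size=2, cp_rank=1))  # tensor([ 4,  5,  6,  7,  8,  9, 10, 11])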
