From ed1e85c4a6502999541f46d611622ec283aa9680 Mon Sep 17 00:00:00 2001
From: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com>
Date: Tue, 29 Oct 2024 17:20:59 -0700
Subject: [PATCH] Add missed arguments of apply_rotary_pos_emb in MHA (#1296)

* add missed arguments of apply_rotary_pos_emb in MHA

Signed-off-by: Xiaowei Ren

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove an unnecessary f

Signed-off-by: Xiaowei Ren

* add one more assert for cp_group len

Signed-off-by: Xiaowei Ren

---------

Signed-off-by: Xiaowei Ren
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 transformer_engine/pytorch/attention.py | 37 +++++++++++++++++++++++--
 1 file changed, 35 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 8e8a3d9e37..3d72c6a9b3 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -8495,6 +8495,8 @@ def __init__(
         self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.num_attention_heads = num_attention_heads
         self.return_bias = return_bias
+        self.cp_size = 1
+        self.cp_rank = 0
 
         kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)
 
@@ -8713,6 +8715,21 @@ def set_context_parallel_group(
         across each CP sub-group (e.g., via NVLink), then exchanging KV with p2p between
         sub-groups (e.g., via IBLink).
         """
+        if isinstance(cp_group, dist_group_type):
+            self.cp_size = get_distributed_world_size(cp_group)
+            self.cp_rank = get_distributed_rank(cp_group)
+        elif isinstance(cp_group, list):
+            assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!"
+            assert (
+                cp_comm_type == "a2a+p2p"
+            ), "Only cp_comm_type of a2a+p2p requires hierarchical CP groups!"
+            cp_size_a2a = get_distributed_world_size(cp_group[0])
+            cp_rank_a2a = get_distributed_rank(cp_group[0])
+            cp_size_p2p = get_distributed_world_size(cp_group[1])
+            cp_rank_p2p = get_distributed_rank(cp_group[1])
+            self.cp_size = cp_size_a2a * cp_size_p2p
+            self.cp_rank = cp_size_a2a * cp_rank_p2p + cp_rank_a2a
+
         # Deep iterate but skip self to avoid infinite recursion.
         for index, child in enumerate(self.modules()):
             if index == 0:
@@ -9047,8 +9064,24 @@ def forward(
                 q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
                 k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]
 
-            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
-            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)
+            query_layer = apply_rotary_pos_emb(
+                query_layer,
+                q_pos_emb,
+                self.qkv_format,
+                fused=True,
+                cu_seqlens=cu_seqlens_q,
+                cp_size=self.cp_size,
+                cp_rank=self.cp_rank,
+            )
+            key_layer = apply_rotary_pos_emb(
+                key_layer,
+                k_pos_emb,
+                self.qkv_format,
+                fused=True,
+                cu_seqlens=cu_seqlens_kv,
+                cp_size=self.cp_size,
+                cp_rank=self.cp_rank,
+            )
 
             # ===========================
             # Core attention computation
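
A minimal sketch (not part of the patch) of the two-level CP rank/size arithmetic added in set_context_parallel_group above. The helper name is hypothetical and it uses plain integers rather than torch.distributed process groups; it only illustrates how the a2a (inner) and p2p (outer) sub-groups compose into the single flat cp_size/cp_rank passed to apply_rotary_pos_emb, assuming an a2a-first rank layout as in the patch.

def flatten_hierarchical_cp_rank(cp_size_a2a, cp_rank_a2a, cp_size_p2p, cp_rank_p2p):
    """Compose a flat CP size/rank from a2a (inner) and p2p (outer) sub-groups."""
    cp_size = cp_size_a2a * cp_size_p2p
    cp_rank = cp_size_a2a * cp_rank_p2p + cp_rank_a2a
    return cp_size, cp_rank

# Example: 2-way a2a within a node and 4-way p2p across nodes give 8-way CP.
# The rank with a2a rank 1 and p2p rank 2 maps to flat rank 2 * 2 + 1 = 5.
assert flatten_hierarchical_cp_rank(2, 1, 4, 2) == (8, 5)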