diff --git a/tests/ut/distributed/test_parallel_state.py b/tests/ut/distributed/test_parallel_state.py
index 15a5c50986b..4a9109166ed 100644
--- a/tests/ut/distributed/test_parallel_state.py
+++ b/tests/ut/distributed/test_parallel_state.py
@@ -24,11 +24,13 @@ def mock_distributed():
             patch('torch.distributed.get_backend', return_value='nccl'), \
             patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group, \
             patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group, \
-            patch('vllm_ascend.distributed.parallel_state.get_dp_group') as mock_dp_group:
+            patch('vllm_ascend.distributed.parallel_state.get_dp_group') as mock_dp_group, \
+            patch('vllm_ascend.distributed.parallel_state.get_pp_group') as mock_pp_group:
         mock_group.return_value.local_rank = 0
         mock_group.return_value.device_group = MagicMock()
         mock_tp_group.return_value.world_size = 4
         mock_dp_group.return_value.world_size = 2
+        mock_pp_group.return_value.world_size = 2
 
         yield
diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py
index 9b5dde0fee1..00de0627b4c 100644
--- a/vllm_ascend/distributed/parallel_state.py
+++ b/vllm_ascend/distributed/parallel_state.py
@@ -3,7 +3,8 @@
 import torch
 from vllm.config import ParallelConfig, get_current_vllm_config
 from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
-                                             get_tp_group, get_world_group,
+                                             get_pp_group, get_tp_group,
+                                             get_world_group,
                                              init_model_parallel_group)
 
 import vllm_ascend.envs as envs_ascend
@@ -185,6 +186,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
     ).flashcomm2_oproj_tensor_parallel_size
     global_tp_size = get_tp_group().world_size
     global_dp_size = get_dp_group().world_size
+    global_pp_size = get_pp_group().world_size
 
     num_fc2_oproj_tensor_parallel_groups: int = (global_tp_size //
                                                  flashcomm2_otp_size)
@@ -197,18 +199,27 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
     if flashcomm2_otp_size > 1:
         otp_group_ranks = []
         odp_group_ranks: list[list[int]] = [
-            [] for _ in range(flashcomm2_otp_size * global_dp_size)
+            [] for _ in range(flashcomm2_otp_size * global_dp_size *
+                              global_pp_size)
         ]
-        for dp_group_index in range(global_dp_size):
-            for i in range(num_fc2_oproj_tensor_parallel_groups):
-                ranks = []
-                for j in range(flashcomm2_otp_size):
-                    rank_idx = dp_group_index * global_tp_size + i + j * num_fc2_oproj_tensor_parallel_groups
-                    ranks.append(rank_idx)
-                    odp_group_index = dp_group_index * flashcomm2_otp_size + j
-                    odp_group_ranks[odp_group_index].append(rank_idx)
-                otp_group_ranks.append(ranks)
+        for dp_group_index in range(global_dp_size):
+            for pp_group_index in range(global_pp_size):
+                dp_pp_serial_index = dp_group_index * global_pp_size + pp_group_index
+                tp_base_rank = dp_pp_serial_index * global_tp_size
+                odp_base_index = dp_pp_serial_index * flashcomm2_otp_size
+
+                for i in range(num_fc2_oproj_tensor_parallel_groups):
+                    ranks = []
+                    for j in range(flashcomm2_otp_size):
+                        tp_local_rank = i + j * num_fc2_oproj_tensor_parallel_groups
+                        assert tp_local_rank < global_tp_size
+                        global_rank = tp_base_rank + tp_local_rank
+                        ranks.append(global_rank)
+
+                        odp_group_index = odp_base_index + j
+                        odp_group_ranks[odp_group_index].append(
+                            global_rank)
+                    otp_group_ranks.append(ranks)
         _FLASHCOMM2_OTP = init_model_parallel_group(
             otp_group_ranks,
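
Reviewer note (not part of the patch): the new loop body serializes every (dp, pp) pair into dp_pp_serial_index, gives each pair a contiguous block of global_tp_size global ranks starting at tp_base_rank, and carves each block into strided OTP groups (members sit num_fc2_oproj_tensor_parallel_groups apart) plus ODP groups that collect, for a fixed shard index j, one rank from every OTP group of that block. The standalone sketch below reproduces that layout under assumed example sizes (TP=4, OTP=2, DP=2, PP=2, i.e. world size 16); it imports nothing from vllm or vllm_ascend, and the variable names simply mirror the patch.

# Assumed example sizes; any values with global_tp_size divisible by
# flashcomm2_otp_size work the same way.
global_tp_size = 4       # tensor-parallel size
flashcomm2_otp_size = 2  # oproj tensor-parallel size
global_dp_size = 2       # data-parallel size
global_pp_size = 2       # pipeline-parallel size
num_fc2_oproj_tensor_parallel_groups = global_tp_size // flashcomm2_otp_size

otp_group_ranks = []
odp_group_ranks = [[] for _ in range(flashcomm2_otp_size * global_dp_size *
                                     global_pp_size)]

for dp_group_index in range(global_dp_size):
    for pp_group_index in range(global_pp_size):
        # Each (dp, pp) pair owns one contiguous block of global_tp_size ranks.
        dp_pp_serial_index = dp_group_index * global_pp_size + pp_group_index
        tp_base_rank = dp_pp_serial_index * global_tp_size
        odp_base_index = dp_pp_serial_index * flashcomm2_otp_size
        for i in range(num_fc2_oproj_tensor_parallel_groups):
            ranks = []
            for j in range(flashcomm2_otp_size):
                # OTP members are a fixed stride apart inside the TP block.
                global_rank = tp_base_rank + i + j * num_fc2_oproj_tensor_parallel_groups
                ranks.append(global_rank)
                # ODP group j of this block takes one rank per OTP group.
                odp_group_ranks[odp_base_index + j].append(global_rank)
            otp_group_ranks.append(ranks)

print(otp_group_ranks)
# [[0, 2], [1, 3], [4, 6], [5, 7], [8, 10], [9, 11], [12, 14], [13, 15]]
print(odp_group_ranks)
# [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]]

With PP folded into the serialization, ranks 4-7 (dp=0, pp=1) get their own OTP/ODP groups instead of colliding with the pp=0 block, which is exactly what the widened odp_group_ranks allocation and the dp_pp_serial_index arithmetic in the patch provide.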