4 changes: 3 additions & 1 deletion tests/ut/distributed/test_parallel_state.py
@@ -24,11 +24,13 @@ def mock_distributed():
          patch('torch.distributed.get_backend', return_value='nccl'), \
          patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group, \
          patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group, \
-         patch('vllm_ascend.distributed.parallel_state.get_dp_group') as mock_dp_group:
+         patch('vllm_ascend.distributed.parallel_state.get_dp_group') as mock_dp_group, \
+         patch('vllm_ascend.distributed.parallel_state.get_pp_group') as mock_pp_group:
         mock_group.return_value.local_rank = 0
         mock_group.return_value.device_group = MagicMock()
         mock_tp_group.return_value.world_size = 4
         mock_dp_group.return_value.world_size = 2
+        mock_pp_group.return_value.world_size = 2
         yield


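For reference, a standalone sketch of the stubbing pattern the fixture relies on. `fake_module` is a stand-in used purely for illustration and is not part of the test suite; the point is only that assigning attributes on a patched callable's `return_value` is what makes `get_pp_group().world_size` resolve to the stubbed number inside the fixture above.

```python
from unittest.mock import MagicMock, patch

# Illustration only: `fake_module` stands in for the patched module.
fake_module = MagicMock()
with patch.object(fake_module, 'get_pp_group') as mock_pp_group:
    mock_pp_group.return_value.world_size = 2
    # Code that calls get_pp_group().world_size inside this context sees 2.
    assert fake_module.get_pp_group().world_size == 2
```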
33 changes: 22 additions & 11 deletions vllm_ascend/distributed/parallel_state.py
@@ -3,7 +3,8 @@
 import torch
 from vllm.config import ParallelConfig, get_current_vllm_config
 from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
-                                             get_tp_group, get_world_group,
+                                             get_pp_group, get_tp_group,
+                                             get_world_group,
                                              init_model_parallel_group)
 
 import vllm_ascend.envs as envs_ascend
@@ -185,6 +186,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
     ).flashcomm2_oproj_tensor_parallel_size
     global_tp_size = get_tp_group().world_size
     global_dp_size = get_dp_group().world_size
+    global_pp_size = get_pp_group().world_size
     num_fc2_oproj_tensor_parallel_groups: int = (global_tp_size //
                                                  flashcomm2_otp_size)
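As a quick sanity check of the sizing arithmetic, here is the group-count calculation with assumed example sizes (TP=8, DP=2, PP=2, flashcomm2_otp_size=2); these numbers are illustrative and not taken from the PR:

```python
# Assumed example sizes, for illustration only.
global_tp_size = 8
global_dp_size = 2
global_pp_size = 2
flashcomm2_otp_size = 2

# Each TP group of 8 ranks splits into 8 // 2 = 4 oproj tensor-parallel groups.
num_fc2_oproj_tensor_parallel_groups = global_tp_size // flashcomm2_otp_size
assert num_fc2_oproj_tensor_parallel_groups == 4

# One oproj data-parallel bucket per (dp, pp) replica and per oproj shard
# index j: 2 * 2 * 2 = 8 buckets, which is the length of odp_group_ranks
# allocated in the next hunk.
assert flashcomm2_otp_size * global_dp_size * global_pp_size == 8
```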

@@ -197,18 +199,27 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
     if flashcomm2_otp_size > 1:
         otp_group_ranks = []
         odp_group_ranks: list[list[int]] = [
-            [] for _ in range(flashcomm2_otp_size * global_dp_size)
+            [] for _ in range(flashcomm2_otp_size * global_dp_size *
+                              global_pp_size)
         ]
 
         for dp_group_index in range(global_dp_size):
-            for i in range(num_fc2_oproj_tensor_parallel_groups):
-                ranks = []
-                for j in range(flashcomm2_otp_size):
-                    rank_idx = dp_group_index * global_tp_size + i + j * num_fc2_oproj_tensor_parallel_groups
-                    ranks.append(rank_idx)
-                    odp_group_index = dp_group_index * flashcomm2_otp_size + j
-                    odp_group_ranks[odp_group_index].append(rank_idx)
-                otp_group_ranks.append(ranks)
+            for pp_group_index in range(global_pp_size):
+                dp_pp_serial_index = dp_group_index * global_pp_size + pp_group_index
+                tp_base_rank = dp_pp_serial_index * global_tp_size
+                odp_base_index = dp_pp_serial_index * flashcomm2_otp_size
+
+                for i in range(num_fc2_oproj_tensor_parallel_groups):
+                    ranks = []
+                    for j in range(flashcomm2_otp_size):
+                        tp_local_rank = i + j * num_fc2_oproj_tensor_parallel_groups
+                        assert tp_local_rank < global_tp_size
+                        global_rank = tp_base_rank + tp_local_rank
+                        ranks.append(global_rank)
+
+                        odp_group_index = odp_base_index + j
+                        odp_group_ranks[odp_group_index].append(
+                            global_rank)
+                    otp_group_ranks.append(ranks)
 
         _FLASHCOMM2_OTP = init_model_parallel_group(
             otp_group_ranks,
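To make the new rank layout concrete, below is a self-contained sketch that replays the loop above for a small assumed topology (DP=2, PP=2, TP=4, flashcomm2_otp_size=2, i.e. 16 global ranks). The sizes are assumptions chosen for illustration; the loop body mirrors the added code, and rank numbering follows the indexing it encodes (TP innermost, then PP, then DP).

```python
# Assumed topology for illustration: 2 DP x 2 PP x TP=4 (16 global ranks).
global_tp_size = 4
global_dp_size = 2
global_pp_size = 2
flashcomm2_otp_size = 2
num_fc2_oproj_tensor_parallel_groups = global_tp_size // flashcomm2_otp_size

otp_group_ranks: list[list[int]] = []
odp_group_ranks: list[list[int]] = [
    [] for _ in range(flashcomm2_otp_size * global_dp_size * global_pp_size)
]

for dp_group_index in range(global_dp_size):
    for pp_group_index in range(global_pp_size):
        # Each (dp, pp) pair owns a contiguous block of global_tp_size ranks.
        dp_pp_serial_index = dp_group_index * global_pp_size + pp_group_index
        tp_base_rank = dp_pp_serial_index * global_tp_size
        odp_base_index = dp_pp_serial_index * flashcomm2_otp_size
        for i in range(num_fc2_oproj_tensor_parallel_groups):
            ranks = []
            for j in range(flashcomm2_otp_size):
                tp_local_rank = i + j * num_fc2_oproj_tensor_parallel_groups
                global_rank = tp_base_rank + tp_local_rank
                ranks.append(global_rank)
                odp_group_ranks[odp_base_index + j].append(global_rank)
            otp_group_ranks.append(ranks)

print(otp_group_ranks)
# [[0, 2], [1, 3], [4, 6], [5, 7], [8, 10], [9, 11], [12, 14], [13, 15]]
print(odp_group_ranks)
# [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]]
```

Each oproj tensor-parallel group pairs ranks that are num_fc2_oproj_tensor_parallel_groups apart within one TP group, and each oproj data-parallel bucket collects the ranks holding the same shard index j within one (dp, pp) replica, which is why the bucket count now scales with global_pp_size as well.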