12 changes: 8 additions & 4 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -1008,11 +1008,14 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
# Enable different block lengths for different layers when MLA is used.
self.block_len_per_layer = list[int]()
self.slot_size_per_layer = list[int]() # HD bytes in kv terms
self.device_id = self.tp_rank
for layer_name, cache_or_caches in xfer_buffers.items():
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]

for cache in cache_list:
base_addr = cache.data_ptr()
if not self.use_host_buffer and current_platform.is_cuda_alike():
self.device_id = cache.device.index
if base_addr in seen_base_addresses:
continue

@@ -1040,7 +1043,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"All kv cache tensors must have the same size"
)
caches_data.append(
(base_addr, curr_tensor_size_bytes, self.tp_rank, "")
(base_addr, curr_tensor_size_bytes, self.device_id, "")
)

logger.debug(
@@ -1087,7 +1090,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
block_offset = block_id * self.block_len_per_layer[i]
addr = base_addr + block_offset
# (addr, len, device id)
blocks_data.append((addr, kv_block_len, self.tp_rank))
blocks_data.append((addr, kv_block_len, self.device_id))

if self._use_flashinfer:
# Separate and interleave K/V regions to maintain the same
@@ -1098,12 +1101,13 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
addr = base_addr + block_offset
# Register addresses for V cache (K registered first).
v_addr = addr + kv_block_len
blocks_data.append((v_addr, kv_block_len, self.tp_rank))
blocks_data.append((v_addr, kv_block_len, self.device_id))
logger.debug(
"Created %s blocks for src engine %s and rank %s",
"Created %s blocks for src engine %s and rank %s on device id %s",
len(blocks_data),
self.engine_id,
self.tp_rank,
self.device_id,
)

descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
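The diff above switches the NIXL registration metadata from the TP rank to the CUDA device index of the KV-cache tensor (`cache.device.index`), falling back to `self.tp_rank` when a host buffer is used or the platform is not CUDA-alike. The two values diverge once data parallelism packs several TP groups onto one node. A minimal sketch of that distinction; `pick_device_id` is an illustrative helper, not connector code, and `cache.is_cuda` stands in for the platform check:

```python
import torch

def pick_device_id(cache: torch.Tensor, tp_rank: int, use_host_buffer: bool) -> int:
    """Prefer the tensor's own CUDA device index when registering device memory;
    otherwise keep the TP rank (the pre-patch behaviour)."""
    if not use_host_buffer and cache.is_cuda:
        return cache.device.index
    return tp_rank

# With DP=2 and TP=2 on one 4-GPU node, the TP-rank-0 worker of DP replica 1
# holds its KV cache on cuda:2, so tp_rank (0) and the device index (2) differ.
if torch.cuda.is_available() and torch.cuda.device_count() >= 3:
    kv_cache = torch.empty(16, device="cuda:2")
    assert pick_device_id(kv_cache, tp_rank=0, use_host_buffer=False) == 2
```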
11 changes: 10 additions & 1 deletion vllm/v1/engine/utils.py
@@ -134,9 +134,18 @@ def __init__(
data_parallel = vllm_config.parallel_config.data_parallel_size > 1
try:
for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
# Adjust device control in DP for non-CUDA platforms
# as well as external and ray launchers
# For CUDA platforms, we use torch.cuda.set_device()
with (
set_device_control_env_var(vllm_config, local_dp_rank)
if (data_parallel)
if (
data_parallel
and (
not current_platform.is_cuda_alike()
or vllm_config.parallel_config.use_ray
)
)
else contextlib.nullcontext()
):
proc.start()
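This hunk narrows when the `set_device_control_env_var` context manager is entered: only for data-parallel runs where the platform is not CUDA-alike or Ray launches the workers, since CUDA workers now pin their device via `torch.cuda.set_device()` (see the `gpu_worker.py` change below). A small sketch of the same conditional-context-manager pattern; `should_pin_devices_via_env` and `set_env_ctx` are illustrative stand-ins, not vLLM APIs:

```python
import contextlib

def should_pin_devices_via_env(data_parallel: bool, is_cuda_alike: bool, use_ray: bool) -> bool:
    # Env-var based device control is only needed when DP is on and either the
    # platform is not CUDA-alike or Ray launches the workers.
    return data_parallel and (not is_cuda_alike or use_ray)

def start_worker(proc, set_env_ctx, *, data_parallel: bool, is_cuda_alike: bool, use_ray: bool):
    ctx = (
        set_env_ctx()  # e.g. exports the per-rank *_VISIBLE_DEVICES for the child process
        if should_pin_devices_via_env(data_parallel, is_cuda_alike, use_ray)
        else contextlib.nullcontext()
    )
    with ctx:
        proc.start()
```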
4 changes: 2 additions & 2 deletions vllm/v1/worker/dp_utils.py
@@ -8,7 +8,6 @@
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.worker.ubatch_utils import (
UBatchSlices,
check_ubatch_thresholds,
@@ -20,7 +19,8 @@


def _get_device_and_group(parallel_config: ParallelConfig):
device = current_platform.device_type
# Use the actual device assigned to the DP group, not just the device type
device = get_dp_group().device
group = get_dp_group().device_group

# Transfering this tensor from GPU to CPU will introduce a GPU sync
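With the platform-level `device_type`, the coordination tensor is placed on a bare device type such as `"cuda"`, which resolves to whatever the current default device happens to be rather than the GPU the DP group was initialized on; `get_dp_group().device` keeps the tensor on the group's own device. A rough sketch of the intent, assuming `dp_group` exposes `.device` and `.device_group` as in vLLM, with `sync_num_tokens_across_dp` as a hypothetical example:

```python
import torch
import torch.distributed as dist

def sync_num_tokens_across_dp(num_tokens: int, dp_group) -> int:
    device = dp_group.device        # e.g. torch.device("cuda", 3), not just "cuda"
    group = dp_group.device_group
    t = torch.tensor([num_tokens], dtype=torch.int64, device=device)
    # NCCL expects the tensor to live on the GPU this rank communicates from;
    # a bare device type would resolve to the current default device instead.
    dist.all_reduce(t, group=group)
    return int(t.item())
```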
23 changes: 23 additions & 0 deletions vllm/v1/worker/gpu_worker.py
@@ -172,6 +172,29 @@ def init_device(self):
if self.device_config.device.type == "cuda":
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
if (
self.parallel_config.data_parallel_size > 1
and self.parallel_config.data_parallel_size_local > 0
and self.parallel_config.distributed_executor_backend
not in ["ray", "external_launcher"]
and self.vllm_config.parallel_config.data_parallel_backend != "ray"
):
# Use local DP rank if available, otherwise use global DP rank.
dp_local_rank = self.parallel_config.data_parallel_rank_local
if dp_local_rank is None:
dp_local_rank = self.parallel_config.data_parallel_rank

tp_pp_world_size = (
self.parallel_config.pipeline_parallel_size
* self.parallel_config.tensor_parallel_size
)

# DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
self.local_rank += dp_local_rank * tp_pp_world_size
assert self.local_rank < torch.cuda.device_count(), (
f"DP adjusted local rank {self.local_rank} is out of bounds. "
)

self.device = torch.device(f"cuda:{self.local_rank}")
current_platform.set_device(self.device)

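The new block remaps each worker's local rank when multiple DP replicas share a node and neither Ray nor the external launcher manages placement, so DP replica 1 does not reuse the GPUs of replica 0. A worked example of the `DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK` arithmetic from the diff, with illustrative sizes:

```python
def adjusted_local_rank(tp_pp_local_rank: int, dp_local_rank: int,
                        tensor_parallel_size: int, pipeline_parallel_size: int) -> int:
    tp_pp_world_size = tensor_parallel_size * pipeline_parallel_size
    # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
    return dp_local_rank * tp_pp_world_size + tp_pp_local_rank

# One 8-GPU node, DP=2 (both replicas local), TP=2, PP=2:
assert [adjusted_local_rank(r, 0, 2, 2) for r in range(4)] == [0, 1, 2, 3]  # replica 0 -> GPUs 0-3
assert [adjusted_local_rank(r, 1, 2, 2) for r in range(4)] == [4, 5, 6, 7]  # replica 1 -> GPUs 4-7
```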