Commit 60f76ba
[Misc] Replace CUDA_VISIBLE_DEVICES in DP with torch.cuda.set_device for device selection on cuda-like devices (#27564)
Signed-off-by: ilmarkov <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
1 parent e5e076c commit 60f76ba
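
The change swaps environment-based GPU masking for an explicit device bind. Below is a minimal sketch of the two approaches the title contrasts, assuming the environment variable is set before CUDA is initialized; the helper names are illustrative, not vLLM's API.

import os
import torch

def select_via_visible_devices(local_rank: int) -> torch.device:
    # Masking visibility must happen before CUDA init; the chosen GPU is then
    # remapped to index 0 inside the process and the other GPUs disappear.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(local_rank)
    return torch.device("cuda:0")

def select_via_set_device(local_rank: int) -> torch.device:
    # Binding the default device keeps every GPU visible, so code that needs
    # real node-local device indices (e.g. KV-cache transfer) still sees them.
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    return device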

File tree: 4 files changed, +43 / -7 lines changed


vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 8 additions & 4 deletions
@@ -1008,11 +1008,14 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         # Enable different block lengths for different layers when MLA is used.
         self.block_len_per_layer = list[int]()
         self.slot_size_per_layer = list[int]()  # HD bytes in kv terms
+        self.device_id = self.tp_rank
         for layer_name, cache_or_caches in xfer_buffers.items():
             cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]

             for cache in cache_list:
                 base_addr = cache.data_ptr()
+                if not self.use_host_buffer and current_platform.is_cuda_alike():
+                    self.device_id = cache.device.index
                 if base_addr in seen_base_addresses:
                     continue

@@ -1040,7 +1043,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                     "All kv cache tensors must have the same size"
                 )
                 caches_data.append(
-                    (base_addr, curr_tensor_size_bytes, self.tp_rank, "")
+                    (base_addr, curr_tensor_size_bytes, self.device_id, "")
                 )

         logger.debug(
@@ -1087,7 +1090,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                 block_offset = block_id * self.block_len_per_layer[i]
                 addr = base_addr + block_offset
                 # (addr, len, device id)
-                blocks_data.append((addr, kv_block_len, self.tp_rank))
+                blocks_data.append((addr, kv_block_len, self.device_id))

             if self._use_flashinfer:
                 # Separate and interleave K/V regions to maintain the same
@@ -1098,12 +1101,13 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                     addr = base_addr + block_offset
                     # Register addresses for V cache (K registered first).
                     v_addr = addr + kv_block_len
-                    blocks_data.append((v_addr, kv_block_len, self.tp_rank))
+                    blocks_data.append((v_addr, kv_block_len, self.device_id))
         logger.debug(
-            "Created %s blocks for src engine %s and rank %s",
+            "Created %s blocks for src engine %s and rank %s on device id %s",
             len(blocks_data),
             self.engine_id,
             self.tp_rank,
+            self.device_id,
         )

         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
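
Because the worker no longer masks GPUs, a KV-cache tensor's device index can differ from the TP rank. A small sketch of the lookup the hunk above relies on, assuming at least one visible CUDA device; the device choice is illustrative.

import torch

if torch.cuda.is_available():
    # With all GPUs visible, a DP-adjusted worker may sit on a non-zero device.
    torch.cuda.set_device(torch.cuda.device_count() - 1)
    cache = torch.empty(16, device="cuda")  # allocated on the current device
    device_id = cache.device.index          # real node-local id, not the TP rank
    print(device_id)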

vllm/v1/engine/utils.py

Lines changed: 10 additions & 1 deletion
@@ -134,9 +134,18 @@ def __init__(
         data_parallel = vllm_config.parallel_config.data_parallel_size > 1
         try:
             for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
+                # Adjust device control in DP for non-CUDA platforms
+                # as well as external and ray launchers
+                # For CUDA platforms, we use torch.cuda.set_device()
                 with (
                     set_device_control_env_var(vllm_config, local_dp_rank)
-                    if (data_parallel)
+                    if (
+                        data_parallel
+                        and (
+                            not current_platform.is_cuda_alike()
+                            or vllm_config.parallel_config.use_ray
+                        )
+                    )
                     else contextlib.nullcontext()
                 ):
                     proc.start()
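
The hunk above only applies env-var based device control when it is still needed (non-CUDA-alike platforms or Ray). A self-contained sketch of that conditional context-manager pattern, using a placeholder variable name FAKE_VISIBLE_DEVICES and a stub helper rather than vLLM's real ones:

import contextlib
import os

@contextlib.contextmanager
def set_device_control_env_var_stub(local_dp_rank: int):
    # Stand-in for vLLM's set_device_control_env_var: scope an env var to the
    # child-process launch and restore the previous value afterwards.
    old = os.environ.get("FAKE_VISIBLE_DEVICES")
    os.environ["FAKE_VISIBLE_DEVICES"] = str(local_dp_rank)
    try:
        yield
    finally:
        if old is None:
            os.environ.pop("FAKE_VISIBLE_DEVICES", None)
        else:
            os.environ["FAKE_VISIBLE_DEVICES"] = old

needs_env_control = False  # e.g. data parallel and (not cuda-alike or Ray)
with (
    set_device_control_env_var_stub(0)
    if needs_env_control
    else contextlib.nullcontext()
):
    pass  # proc.start() would go here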

vllm/v1/worker/dp_utils.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,6 @@
 from vllm.config import ParallelConfig
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
 from vllm.v1.worker.ubatch_utils import (
     UBatchSlices,
     check_ubatch_thresholds,
@@ -20,7 +19,8 @@


 def _get_device_and_group(parallel_config: ParallelConfig):
-    device = current_platform.device_type
+    # Use the actual device assigned to the DP group, not just the device type
+    device = get_dp_group().device
     group = get_dp_group().device_group

     # Transfering this tensor from GPU to CPU will introduce a GPU sync
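
Returning the DP group's concrete torch.device instead of a bare device-type string matters once workers no longer all sit on local device 0. A short sketch of the difference; the index 1 is only an example.

import torch

device_type_only = torch.device("cuda")     # no index: resolves to the current device
dp_group_device = torch.device("cuda", 1)   # explicit index, as stored on the DP group
print(device_type_only.index, dp_group_device.index)  # None 1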

vllm/v1/worker/gpu_worker.py

Lines changed: 23 additions & 0 deletions
@@ -172,6 +172,29 @@ def init_device(self):
         if self.device_config.device.type == "cuda":
             # This env var set by Ray causes exceptions with graph building.
             os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
+            if (
+                self.parallel_config.data_parallel_size > 1
+                and self.parallel_config.data_parallel_size_local > 0
+                and self.parallel_config.distributed_executor_backend
+                not in ["ray", "external_launcher"]
+                and self.vllm_config.parallel_config.data_parallel_backend != "ray"
+            ):
+                # Use local DP rank if available, otherwise use global DP rank.
+                dp_local_rank = self.parallel_config.data_parallel_rank_local
+                if dp_local_rank is None:
+                    dp_local_rank = self.parallel_config.data_parallel_rank
+
+                tp_pp_world_size = (
+                    self.parallel_config.pipeline_parallel_size
+                    * self.parallel_config.tensor_parallel_size
+                )
+
+                # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
+                self.local_rank += dp_local_rank * tp_pp_world_size
+                assert self.local_rank < torch.cuda.device_count(), (
+                    f"DP adjusted local rank {self.local_rank} is out of bounds. "
+                )
+
             self.device = torch.device(f"cuda:{self.local_rank}")
             current_platform.set_device(self.device)
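
A worked example of the DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK arithmetic in the hunk above, with illustrative sizes (DP=2, TP=2, PP=2 on an 8-GPU node); the numbers are assumptions, not taken from the commit.

pipeline_parallel_size = 2
tensor_parallel_size = 2
tp_pp_world_size = pipeline_parallel_size * tensor_parallel_size  # 4

for dp_local_rank in range(2):
    for tp_local_rank in range(tp_pp_world_size):
        local_rank = dp_local_rank * tp_pp_world_size + tp_local_rank
        print(f"DP {dp_local_rank}, TP/PP {tp_local_rank} -> cuda:{local_rank}")
# DP rank 0 uses cuda:0-3, DP rank 1 uses cuda:4-7.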
