14 changes: 11 additions & 3 deletions lmdeploy/cli/serve.py
@@ -165,22 +165,23 @@ def add_parser_api_server():
max_prefill_token_num_act = ArgumentHelper.max_prefill_token_num(pt_group)
quant_policy = ArgumentHelper.quant_policy(pt_group)
model_format = ArgumentHelper.model_format(pt_group)
ArgumentHelper.dp(pt_group)
dp_act = ArgumentHelper.dp(pt_group)
num_nodes_act = ArgumentHelper.num_nodes(pt_group)
ArgumentHelper.ep(pt_group)
ArgumentHelper.enable_microbatch(pt_group)
ArgumentHelper.enable_eplb(pt_group)
ArgumentHelper.enable_metrics(pt_group)
ArgumentHelper.role(pt_group)
ArgumentHelper.migration_backend(pt_group)
# multi-node serving args
ArgumentHelper.node_rank(parser)
ArgumentHelper.num_nodes(parser)
node_rank_act = ArgumentHelper.node_rank(pt_group)

# turbomind args
tb_group = parser.add_argument_group('TurboMind engine arguments')
# common engine args
tb_group._group_actions.append(dtype_act)
tb_group._group_actions.append(tp_act)
tb_group._group_actions.append(dp_act)
tb_group._group_actions.append(session_len_act)
tb_group._group_actions.append(max_batch_size_act)
tb_group._group_actions.append(cache_max_entry_act)
@@ -189,10 +190,13 @@ def add_parser_api_server():
tb_group._group_actions.append(max_prefill_token_num_act)
tb_group._group_actions.append(quant_policy)
tb_group._group_actions.append(model_format)
tb_group._group_actions.append(num_nodes_act)
tb_group._group_actions.append(node_rank_act)
ArgumentHelper.rope_scaling_factor(tb_group)
ArgumentHelper.num_tokens_per_iter(tb_group)
ArgumentHelper.max_prefill_iters(tb_group)
ArgumentHelper.communicator(tb_group)
ArgumentHelper.ngpus_per_node(tb_group)

# vlm args
vision_group = parser.add_argument_group('Vision model arguments')
@@ -342,6 +346,10 @@ def api_server(args):
from lmdeploy.messages import TurbomindEngineConfig
backend_config = TurbomindEngineConfig(dtype=args.dtype,
tp=args.tp,
dp=args.dp,
nnodes=args.nnodes,
ngpus_per_node=args.ngpus_per_node,
node_rank=args.node_rank,
max_batch_size=max_batch_size,
session_len=args.session_len,
model_format=args.model_format,
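In effect, the new flags flow straight into `TurbomindEngineConfig` in the hunk above. A minimal sketch of the equivalent programmatic setup (the parallel sizes and the `--node-rank` flag spelling are illustrative assumptions, not taken from this PR):

```python
from lmdeploy.messages import TurbomindEngineConfig

# Roughly what `lmdeploy serve api_server <model> --tp 4 --dp 2 \
#   --nnodes 2 --node-rank 0 --ngpus-per-node 4` builds on node 0.
backend_config = TurbomindEngineConfig(
    tp=4,              # tensor parallel size
    dp=2,              # data parallel size
    nnodes=2,          # total number of nodes
    node_rank=0,       # rank of this node (node 1 passes 1 here)
    ngpus_per_node=4,  # GPUs used on each node
)
```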
6 changes: 6 additions & 0 deletions lmdeploy/cli/utils.py
@@ -192,6 +192,12 @@ def num_nodes(parser):

return parser.add_argument('--nnodes', type=int, default=1, help='The total number of nodes')

@staticmethod
def ngpus_per_node(parser):
"""Add argument ngpus_per_node to parser."""

return parser.add_argument('--ngpus-per-node', type=int, default=None, help='The number of GPUs per node')

@staticmethod
def session_id(parser):
"""Add argument session_id to parser."""
5 changes: 5 additions & 0 deletions lmdeploy/messages.py
@@ -229,12 +229,17 @@ class TurbomindEngineConfig:
model_format: Optional[str] = None
tp: int = 1
dp: int = 1
pp: int = 1
device_num: int = None
attn_tp_size: int = None
attn_dp_size: int = None
mlp_tp_size: int = None
mlp_dp_size: int = None
outer_dp_size: int = None
nnodes: int = 1
node_rank: int = 0
ngpus_per_node: Optional[int] = None
devices: List[int] = None
session_len: Optional[int] = None
max_batch_size: int = None
cache_max_entry_count: float = 0.8
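For orientation, the new fields split into user-facing cluster settings (`nnodes`, `node_rank`, `ngpus_per_node`, `devices`, `pp`) and sizes that `update_parallel_config` derives later (`device_num`, the `attn_*`/`mlp_*` sizes, `outer_dp_size`). A hypothetical construction, with illustrative values:

```python
from lmdeploy.messages import TurbomindEngineConfig

# Describe the cluster and leave the derived sizes at their defaults,
# so that update_parallel_config() can fill them in afterwards.
cfg = TurbomindEngineConfig(
    tp=4, dp=2, pp=1,       # requested parallel sizes
    nnodes=2, node_rank=0,  # this process runs on node 0 of 2
    ngpus_per_node=4,       # 4 GPUs per node, so device_num resolves to 8
)
```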
47 changes: 34 additions & 13 deletions lmdeploy/turbomind/turbomind.py
@@ -5,6 +5,7 @@
import copy
import json
import math
import os
import os.path as osp
import sys
from collections import defaultdict
@@ -84,14 +85,23 @@ def complete_parallel_config(cfg: TurbomindEngineConfig):


def update_parallel_config(cfg: TurbomindEngineConfig):
if cfg.nnodes > 1:
assert cfg.ngpus_per_node is not None or cfg.devices is not None
cfg.devices = cfg.devices or list(range(cfg.ngpus_per_node))
cfg.ngpus_per_node = cfg.ngpus_per_node or len(cfg.devices)
cfg.device_num = cfg.device_num or len(cfg.devices) * cfg.nnodes

if not complete_parallel_config(cfg):
total = cfg.dp * cfg.tp
total = cfg.dp * cfg.tp * cfg.pp
if not cfg.device_num:
count = torch.cuda.device_count()
if total < count:
count = total
cfg.device_num = count
assert cfg.device_num % cfg.pp == 0
assert total % cfg.device_num == 0
if cfg.dp > 1:
total = cfg.device_num // cfg.pp
overlap = total // cfg.device_num
attn_dp_size = overlap
mlp_tp_size = overlap
@@ -102,10 +112,19 @@ def update_parallel_config(cfg: TurbomindEngineConfig):
cfg.mlp_dp_size = 1
cfg.mlp_tp_size = mlp_tp_size * inner_tp_size
assert cfg.attn_dp_size * cfg.attn_tp_size == cfg.mlp_dp_size * cfg.mlp_tp_size
assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.outer_dp_size == cfg.device_num
assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.outer_dp_size * cfg.pp == cfg.device_num
assert cfg.outer_dp_size > 0 and cfg.attn_tp_size > 0
cfg.devices = cfg.devices or list(range(cfg.device_num))


# update devices
if cfg.nnodes == 1:
cfg.devices = cfg.devices if cfg.devices else list(range(cfg.device_num))
cfg.ngpus_per_node = cfg.ngpus_per_node or len(cfg.devices)
# for simplicity, require the dp ranks to split evenly across nodes
assert cfg.outer_dp_size * cfg.attn_dp_size % cfg.nnodes == 0


class TurboMind:
"""LMDeploy's inference engine.

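To trace the multi-node branch above with concrete numbers: `nnodes=2` and `ngpus_per_node=4` fix `devices` and `device_num` before the fallback sizing runs. A sketch, assuming the `attn_*`/`mlp_*` sizes are left unset so `complete_parallel_config` returns `False`:

```python
from lmdeploy.messages import TurbomindEngineConfig
from lmdeploy.turbomind.turbomind import update_parallel_config

cfg = TurbomindEngineConfig(tp=4, dp=2, nnodes=2, ngpus_per_node=4)
update_parallel_config(cfg)
assert cfg.devices == [0, 1, 2, 3]  # list(range(ngpus_per_node))
assert cfg.device_num == 8          # len(devices) * nnodes
# dp * tp * pp == 8 == device_num, and the dp ranks split evenly
# across the 2 nodes, so the asserts in the function all pass.
```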
@@ -141,8 +160,15 @@ def __init__(self,
f' greater than 0, but got {_engine_config.max_batch_size}'

update_parallel_config(_engine_config)

self.gpu_count = _engine_config.device_num
if _engine_config.nnodes > 1 and _engine_config.node_rank == 0:
from torch.distributed import TCPStore
master_addr = os.environ.get('LMDEPLOY_DP_MASTER_ADDR')
master_port = os.environ.get('LMDEPLOY_DP_MASTER_PORT')
assert master_addr is not None and master_port is not None, \
'LMDEPLOY_DP_MASTER_ADDR and LMDEPLOY_DP_MASTER_PORT should be set when using multi-node'
self.store = TCPStore(host_name=master_addr, port=int(master_port), is_master=True)

self.gpu_count = len(_engine_config.devices)
self.devices = _engine_config.devices

self.tokenizer = tokenizer
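The hunk above only shows `node_rank == 0` creating the `TCPStore`; presumably the other nodes join it as clients, along these lines (an illustrative assumption, not code from this PR):

```python
import os

from torch.distributed import TCPStore

# Hypothetical client side for node_rank > 0: connect to the master's store.
store = TCPStore(
    host_name=os.environ['LMDEPLOY_DP_MASTER_ADDR'],
    port=int(os.environ['LMDEPLOY_DP_MASTER_PORT']),
    is_master=False,
)
```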
@@ -196,10 +222,8 @@ def _create_engine(self):
def _create_weight(self, model_comm):
"""Allocate weight buffer, load params if from_workspace."""

# TODO: support mpi
self.node_id = 0
self.node_num = 1
torch.cuda.synchronize()
engine_cfg = self.config_dict['engine_config']
self.node_id = engine_cfg['node_rank']

# create weight
def _create_weight_func(device_id):
Expand Down Expand Up @@ -394,6 +418,8 @@ def close(self):
del self._export_iter
if self.model_comm is not None:
self.model_comm = None
if hasattr(self, 'store'):
del self.store

def create_instance(self, cuda_stream_id=0):
"""Create a turbomind instance.
Expand Down Expand Up @@ -500,11 +526,6 @@ def __init__(self, tm_model: TurboMind, config: TurbomindModelConfig, cuda_strea
self.tm_model = tm_model
self.cuda_stream_id = cuda_stream_id

self.node_id = tm_model.node_id
self.gpu_count = tm_model.gpu_count

self.session_len = tm_model.session_len

# create model instances
lazy_init = self.tm_model.config_dict['engine_config'].get('empty_init', False)
self._model_inst = None if lazy_init else self._create_model_instance(0)
8 changes: 8 additions & 0 deletions src/turbomind/comm/CMakeLists.txt
@@ -20,6 +20,14 @@ if (BUILD_MULTI_GPU)
target_link_libraries(device_comm INTERFACE nccl_comm)
endif ()

add_subdirectory(gloo)
target_link_libraries(host_comm INTERFACE gloo_comm)

add_library(serialize STATIC serialize.cc)
target_link_libraries(serialize PRIVATE core)
set_property(TARGET serialize PROPERTY POSITION_INDEPENDENT_CODE ON)
target_link_libraries(host_comm INTERFACE serialize)

if (BUILD_TEST)
add_executable(test_comm test_comm.cu)
target_link_libraries(test_comm PRIVATE device_comm host_comm core pthread nvtx_utils)
10 changes: 10 additions & 0 deletions src/turbomind/comm/device_comm.h
@@ -106,6 +106,16 @@ class DeviceCommImpl {
{
throw std::runtime_error("not implemented");
}

virtual void Send(const void* sendbuff, size_t count, DataType type, int dst, int group, cudaStream_t stream)
{
throw std::runtime_error("not implemented");
}

virtual void Recv(void* recvbuff, size_t count, DataType type, int src, int group, cudaStream_t stream)
{
throw std::runtime_error("not implemented");
}
};

class DeviceComm {
29 changes: 29 additions & 0 deletions src/turbomind/comm/gloo/CMakeLists.txt
@@ -0,0 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
cmake_minimum_required(VERSION 3.8)

include(FetchContent)
FetchContent_Declare(
gloo
GIT_REPOSITORY https://github.com/pytorch/gloo.git
GIT_TAG c7b7b022c124d9643957d9bd55f57ac59fce8fa2 # pytorch-v2.8.0-rc4
)

# gloo build settings
set(GLOO_INSTALL OFF CACHE BOOL "" FORCE)
set(GLOO_STATIC_OR_SHARED STATIC CACHE STRING "" FORCE)
set(USE_NCCL OFF)
set(BUILD_TEST OFF)
FetchContent_MakeAvailable(gloo)

# gloo build doesn't add include directories as a target property...
target_include_directories(gloo PUBLIC
$<BUILD_INTERFACE:${gloo_SOURCE_DIR}>
$<BUILD_INTERFACE:${gloo_BINARY_DIR}> # config.h generated at cmake config time
)

add_library(gloo_comm STATIC
gloo_comm.cc
tcp_store.cc
)
set_property(TARGET gloo_comm PROPERTY POSITION_INDEPENDENT_CODE ON)
target_link_libraries(gloo_comm PUBLIC gloo logger)