Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lmdeploy/cli/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def add_parser_proxy():
parser.add_argument('--dummy-prefill', action='store_true', help='dummy prefill for performance profiler')
parser.add_argument('--routing-strategy',
type=str,
choices=['random', 'min_expected_latency', 'min_observed_latency'],
choices=['random', 'round_robin', 'min_expected_latency', 'min_observed_latency'],
default='min_expected_latency',
help='the strategy to dispatch requests to nodes')
parser.add_argument('--disable-cache-status',
Expand Down
4 changes: 2 additions & 2 deletions lmdeploy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pydantic.dataclasses import dataclass as pydantic_dataclass

from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend
from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest
from lmdeploy.pytorch.disagg.conn.protocol import MigrationContext

from .tokenizer import Tokenizer
from .utils import get_logger
Expand Down Expand Up @@ -114,7 +114,7 @@ class GenerationConfig:
# for disaggregation
with_cache: bool = False
preserve_cache: bool = False
migration_request: Optional[MigrationRequest] = None
migration_context: Optional[MigrationContext] = None

def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer):
"""Convert stop_words/bad_sords to ids and append the ids to
Expand Down
4 changes: 2 additions & 2 deletions lmdeploy/pytorch/disagg/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import abstractmethod

from lmdeploy.pytorch.disagg.conn.protocol import (DistServeInitRequest, DistServeKVTransferEndpointInfo,
MigrationProtocol)
KVTransferProtocol)
from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment


Expand All @@ -17,7 +17,7 @@ def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage
raise NotImplementedError

@abstractmethod
def endpoint_info(self, remote_engine_id: int, protocol: MigrationProtocol):
def endpoint_info(self, remote_engine_id: int, protocol: KVTransferProtocol):
return NotImplementedError

@abstractmethod
Expand Down
44 changes: 21 additions & 23 deletions lmdeploy/pytorch/disagg/backend/dlslime.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl
from lmdeploy.pytorch.disagg.config import DistServeEngineConfig, MigrationBackend
from lmdeploy.pytorch.disagg.conn.protocol import (DistServeInitRequest, DistServeKVTransferEndpointInfo,
MigrationProtocol)
KVTransferProtocol)
from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment

logger = get_logger('lmdeploy')
Expand Down Expand Up @@ -40,20 +40,20 @@ def __init__(self, init_request: DistServeInitRequest):
self.rank = init_request.rank
self.local_engine_config: DistServeEngineConfig = init_request.local_engine_config
self.remote_engine_config: DistServeEngineConfig = init_request.remote_engine_config
self.endpoint: Dict[MigrationProtocol, RDMAEndpoint] = {
MigrationProtocol.TCP: None,
MigrationProtocol.RDMA: None,
MigrationProtocol.NVLINK: None,
self.endpoint: Dict[KVTransferProtocol, RDMAEndpoint] = {
KVTransferProtocol.TCP: None,
KVTransferProtocol.RDMA: None,
KVTransferProtocol.NVLINK: None,
}
if init_request.protocol == MigrationProtocol.RDMA:
if init_request.kvtransfer_protocol == KVTransferProtocol.RDMA:
nics = available_nic()
device_name = nics[self.rank % len(nics)]
logger.info(f'use device {device_name} for kv migration')
self.endpoint[MigrationProtocol.RDMA] = RDMAEndpoint(device_name=device_name,
self.endpoint[KVTransferProtocol.RDMA] = RDMAEndpoint(device_name=device_name,
ib_port=1,
link_type=init_request.rdma_config.link_type.name)
elif init_request.protocol == MigrationProtocol.NVLINK:
self.endpoint[MigrationProtocol.NVLINK] = NVLinkEndpoint()
elif init_request.kvtransfer_protocol == KVTransferProtocol.NVLINK:
self.endpoint[KVTransferProtocol.NVLINK] = NVLinkEndpoint()

def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage):
self.endpoint[register_mr_request.protocol].register_memory_region(register_mr_request.mr_key,
Expand All @@ -64,7 +64,7 @@ def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage
def connect(self, kvtransfer_endpoint_info: DistServeKVTransferEndpointInfo):
self.endpoint[kvtransfer_endpoint_info.protocol].connect(json.loads(kvtransfer_endpoint_info.endpoint_info))

async def p2p_migrate(self, assignment: MigrationAssignment, async_op=False):
async def p2p_migrate(self, assignment: MigrationAssignment, async_op=False, proactive=True):
batch = [
DLSlimeAssignment(
mr_key=assign.mr_key,
Expand All @@ -74,20 +74,18 @@ async def p2p_migrate(self, assignment: MigrationAssignment, async_op=False):
) for assign in assignment.batch
]

if not LMDEPLOY_USE_ASYNC_MIGRATION:
MAX_NUM_READ_BATCH = 4096
MAX_NUM_READ_BATCH = 4096

def split(batch: List[DLSlimeAssignment]):
batch_split = []
for i in range(0, len(batch), MAX_NUM_READ_BATCH):
batch_split.append(batch[i:i + MAX_NUM_READ_BATCH])
return batch_split
def split(batch: List[DLSlimeAssignment]):
batch_split = []
for i in range(0, len(batch), MAX_NUM_READ_BATCH):
batch_split.append(batch[i:i + MAX_NUM_READ_BATCH])
return batch_split

batch_splited = split(batch)
for b_split in batch_splited:
self.endpoint[assignment.protocol].read_batch(b_split)
else:
await read_batch_coroutine(self.endpoint[assignment.protocol], batch)
batch_splited = split(batch)
for b_split in batch_splited:
logger.error(b_split)
self.endpoint[assignment.protocol].write_batch(b_split)


@MIGRATION_BACKENDS.register_module(MigrationBackend.DLSlime.name)
Expand All @@ -103,7 +101,7 @@ def p2p_initialize(self, init_request: DistServeInitRequest):
def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage):
self.links[register_mr_request.remote_engine_id].register_memory_region(register_mr_request)

def endpoint_info(self, remote_engine_id: int, protocol: MigrationProtocol):
def endpoint_info(self, remote_engine_id: int, protocol: KVTransferProtocol):
return self.links[remote_engine_id].endpoint[protocol].endpoint_info

def p2p_connect(self, remote_engine_id: str, conn_req: DistServeKVTransferEndpointInfo):
Expand Down
6 changes: 3 additions & 3 deletions lmdeploy/pytorch/disagg/backend/mooncake.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl
from lmdeploy.pytorch.disagg.config import MigrationBackend, MooncakeEngineConfig
from lmdeploy.pytorch.disagg.conn.protocol import (DistServeInitRequest, DistServeKVTransferEndpointInfo,
MigrationProtocol)
KVTransferProtocol)
from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment
from lmdeploy.utils import get_logger

Expand Down Expand Up @@ -222,7 +222,7 @@ def _migrate(self, assignment: MigrationAssignment):
logger.debug(f" Remote: 0x{remote_buffer_info['addr']:x} + {task.target_offset} = 0x{remote_addr:x}")
logger.debug(f' Session: {self.remote_url}')

result = self.engine.transfer_sync_read(
result = self.engine.transfer_sync_write(
self.remote_url,
local_addr,
remote_addr,
Expand All @@ -245,7 +245,7 @@ def p2p_initialize(self, init_request: DistServeInitRequest):
def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage):
self.links[register_mr_request.remote_engine_id].register_memory_region(register_mr_request)

def endpoint_info(self, remote_engine_id: int, protocol: MigrationProtocol):
def endpoint_info(self, remote_engine_id: int, protocol: KVTransferProtocol):
return self.links[remote_engine_id].endpoint_info

def p2p_connect(self, remote_engine_id: str, connect_request: DistServeKVTransferEndpointInfo):
Expand Down
Loading