diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index ecd9fdae05..0210654498 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -158,7 +158,7 @@ def add_parser_proxy(): parser.add_argument('--dummy-prefill', action='store_true', help='dummy prefill for performance profiler') parser.add_argument('--routing-strategy', type=str, - choices=['random', 'min_expected_latency', 'min_observed_latency'], + choices=['random', 'round_robin', 'min_expected_latency', 'min_observed_latency'], default='min_expected_latency', help='the strategy to dispatch requests to nodes') parser.add_argument('--disable-cache-status', diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 4bd68e8e0b..252410825b 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -8,7 +8,7 @@ from pydantic.dataclasses import dataclass as pydantic_dataclass from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend -from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest +from lmdeploy.pytorch.disagg.conn.protocol import MigrationContext from .tokenizer import Tokenizer from .utils import get_logger @@ -114,7 +114,7 @@ class GenerationConfig: # for disaggregation with_cache: bool = False preserve_cache: bool = False - migration_request: Optional[MigrationRequest] = None + migration_context: Optional[MigrationContext] = None def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer): """Convert stop_words/bad_sords to ids and append the ids to diff --git a/lmdeploy/pytorch/disagg/backend/base.py b/lmdeploy/pytorch/disagg/backend/base.py index 7e7716dffc..a12a577d80 100644 --- a/lmdeploy/pytorch/disagg/backend/base.py +++ b/lmdeploy/pytorch/disagg/backend/base.py @@ -2,7 +2,7 @@ from abc import abstractmethod from lmdeploy.pytorch.disagg.conn.protocol import (DistServeInitRequest, DistServeKVTransferEndpointInfo, - MigrationProtocol) + KVTransferProtocol) from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment @@ -17,7 +17,7 @@ def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage raise NotImplementedError @abstractmethod - def endpoint_info(self, remote_engine_id: int, protocol: MigrationProtocol): + def endpoint_info(self, remote_engine_id: int, protocol: KVTransferProtocol): return NotImplementedError @abstractmethod diff --git a/lmdeploy/pytorch/disagg/backend/dlslime.py b/lmdeploy/pytorch/disagg/backend/dlslime.py index 80257890b7..fb3d19b5cb 100644 --- a/lmdeploy/pytorch/disagg/backend/dlslime.py +++ b/lmdeploy/pytorch/disagg/backend/dlslime.py @@ -12,7 +12,7 @@ from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl from lmdeploy.pytorch.disagg.config import DistServeEngineConfig, MigrationBackend from lmdeploy.pytorch.disagg.conn.protocol import (DistServeInitRequest, DistServeKVTransferEndpointInfo, - MigrationProtocol) + KVTransferProtocol) from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment logger = get_logger('lmdeploy') @@ -40,20 +40,20 @@ def __init__(self, init_request: DistServeInitRequest): self.rank = init_request.rank self.local_engine_config: DistServeEngineConfig = init_request.local_engine_config self.remote_engine_config: DistServeEngineConfig = init_request.remote_engine_config - self.endpoint: Dict[MigrationProtocol, RDMAEndpoint] = { - MigrationProtocol.TCP: None, - MigrationProtocol.RDMA: None, - MigrationProtocol.NVLINK: None, + self.endpoint: Dict[KVTransferProtocol, RDMAEndpoint] = { + KVTransferProtocol.TCP: None, + 
KVTransferProtocol.RDMA: None, + KVTransferProtocol.NVLINK: None, } - if init_request.protocol == MigrationProtocol.RDMA: + if init_request.kvtransfer_protocol == KVTransferProtocol.RDMA: nics = available_nic() device_name = nics[self.rank % len(nics)] logger.info(f'use device {device_name} for kv migration') - self.endpoint[MigrationProtocol.RDMA] = RDMAEndpoint(device_name=device_name, + self.endpoint[KVTransferProtocol.RDMA] = RDMAEndpoint(device_name=device_name, ib_port=1, link_type=init_request.rdma_config.link_type.name) - elif init_request.protocol == MigrationProtocol.NVLINK: - self.endpoint[MigrationProtocol.NVLINK] = NVLinkEndpoint() + elif init_request.kvtransfer_protocol == KVTransferProtocol.NVLINK: + self.endpoint[KVTransferProtocol.NVLINK] = NVLinkEndpoint() def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage): self.endpoint[register_mr_request.protocol].register_memory_region(register_mr_request.mr_key, @@ -64,7 +64,7 @@ def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage def connect(self, kvtransfer_endpoint_info: DistServeKVTransferEndpointInfo): self.endpoint[kvtransfer_endpoint_info.protocol].connect(json.loads(kvtransfer_endpoint_info.endpoint_info)) - async def p2p_migrate(self, assignment: MigrationAssignment, async_op=False): + async def p2p_migrate(self, assignment: MigrationAssignment, async_op=False, proactive=True): batch = [ DLSlimeAssignment( mr_key=assign.mr_key, @@ -74,20 +74,18 @@ async def p2p_migrate(self, assignment: MigrationAssignment, async_op=False): ) for assign in assignment.batch ] - if not LMDEPLOY_USE_ASYNC_MIGRATION: - MAX_NUM_READ_BATCH = 4096 + MAX_NUM_READ_BATCH = 4096 - def split(batch: List[DLSlimeAssignment]): - batch_split = [] - for i in range(0, len(batch), MAX_NUM_READ_BATCH): - batch_split.append(batch[i:i + MAX_NUM_READ_BATCH]) - return batch_split + def split(batch: List[DLSlimeAssignment]): + batch_split = [] + for i in range(0, len(batch), MAX_NUM_READ_BATCH): + batch_split.append(batch[i:i + MAX_NUM_READ_BATCH]) + return batch_split - batch_splited = split(batch) - for b_split in batch_splited: - self.endpoint[assignment.protocol].read_batch(b_split) - else: - await read_batch_coroutine(self.endpoint[assignment.protocol], batch) + batch_splited = split(batch) + for b_split in batch_splited: + logger.error(b_split) + self.endpoint[assignment.protocol].write_batch(b_split) @MIGRATION_BACKENDS.register_module(MigrationBackend.DLSlime.name) @@ -103,7 +101,7 @@ def p2p_initialize(self, init_request: DistServeInitRequest): def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage): self.links[register_mr_request.remote_engine_id].register_memory_region(register_mr_request) - def endpoint_info(self, remote_engine_id: int, protocol: MigrationProtocol): + def endpoint_info(self, remote_engine_id: int, protocol: KVTransferProtocol): return self.links[remote_engine_id].endpoint[protocol].endpoint_info def p2p_connect(self, remote_engine_id: str, conn_req: DistServeKVTransferEndpointInfo): diff --git a/lmdeploy/pytorch/disagg/backend/mooncake.py b/lmdeploy/pytorch/disagg/backend/mooncake.py index e4ba7fbd5f..b5ee368cb3 100644 --- a/lmdeploy/pytorch/disagg/backend/mooncake.py +++ b/lmdeploy/pytorch/disagg/backend/mooncake.py @@ -10,7 +10,7 @@ from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl from lmdeploy.pytorch.disagg.config import MigrationBackend, MooncakeEngineConfig from lmdeploy.pytorch.disagg.conn.protocol import 
(DistServeInitRequest, DistServeKVTransferEndpointInfo, - MigrationProtocol) + KVTransferProtocol) from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment from lmdeploy.utils import get_logger @@ -222,7 +222,7 @@ def _migrate(self, assignment: MigrationAssignment): logger.debug(f" Remote: 0x{remote_buffer_info['addr']:x} + {task.target_offset} = 0x{remote_addr:x}") logger.debug(f' Session: {self.remote_url}') - result = self.engine.transfer_sync_read( + result = self.engine.transfer_sync_write( self.remote_url, local_addr, remote_addr, @@ -245,7 +245,7 @@ def p2p_initialize(self, init_request: DistServeInitRequest): def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage): self.links[register_mr_request.remote_engine_id].register_memory_region(register_mr_request) - def endpoint_info(self, remote_engine_id: int, protocol: MigrationProtocol): + def endpoint_info(self, remote_engine_id: int, protocol: KVTransferProtocol): return self.links[remote_engine_id].endpoint_info def p2p_connect(self, remote_engine_id: str, connect_request: DistServeKVTransferEndpointInfo): diff --git a/lmdeploy/pytorch/disagg/conn/engine_conn.py b/lmdeploy/pytorch/disagg/conn/engine_conn.py index a14c684f3b..19673c5b0a 100644 --- a/lmdeploy/pytorch/disagg/conn/engine_conn.py +++ b/lmdeploy/pytorch/disagg/conn/engine_conn.py @@ -1,19 +1,42 @@ # Copyright (c) OpenMMLab. All rights reserved. import asyncio import os -from typing import TYPE_CHECKING, Dict, List +import time +import functools +from typing import Any, Awaitable, Callable, Dict, List, Tuple, TypeAlias, TYPE_CHECKING from urllib.parse import urlparse import zmq import zmq.asyncio +import numpy as np + +import torch + from lmdeploy.logger import get_logger + +from lmdeploy.messages import GenerationConfig +from lmdeploy.pytorch.engine.request import ResponseType + +from lmdeploy.pytorch.disagg.config import EngineRole from lmdeploy.pytorch.disagg.conn.protocol import (DistServeCacheFreeRequest, DistServeConnectionRequest, - DistServeConnectionResponse, DistServeConnectionStatus, + DistServeConnectionResponse, DistServeStatus, DistServeDropConnectionRequest, DistServeEngineEndpointInfo, DistServeInitRequest, DistServeInitResponse, - DistServeKVTransferEndpointInfo) + DistServeKVTransferEndpointInfo, DistServeRecomputeRequest, + DistServeRecomputeResponse, DistServeFetchMetaRequest, + DistServeFetchMetaResponse, DistServeProactiveMigrationRequest, + DistServeProactiveMigrationResponse) +from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch + from lmdeploy.pytorch.engine.executor.dist_utils import find_available_port +from lmdeploy.pytorch.engine.mp_engine.engine_instance_pool import EngineInstancePool +from lmdeploy.pytorch.messages import MessageStatus, HistoryTokenIds + +from lmdeploy.pytorch.engine.request import ResponseType +from lmdeploy.pytorch.messages import InferOutput, SchedulerSequence +from lmdeploy.messages import GenerationConfig + if TYPE_CHECKING: from lmdeploy.pytorch.engine.engine import Engine @@ -21,8 +44,18 @@ logger = get_logger('lmdeploy') -class EngineP2PConnection: +DistServeEngineConnCallResponse: TypeAlias = ( + DistServeFetchMetaRequest + | DistServeFetchMetaResponse + | DistServeProactiveMigrationRequest + | DistServeProactiveMigrationResponse + | DistServeCacheFreeRequest + | DistServeRecomputeRequest + | DistServeRecomputeResponse +) + +class EngineP2PConnection: def __init__(self, engine: 'Engine'): self.engine: Engine = engine self.p2p_conn_ctx: 
Dict[str, zmq.asyncio.Context] = {} @@ -30,6 +63,30 @@ def __init__(self, engine: 'Engine'): self.p2p_receiver: Dict[str, zmq.asyncio.Socket] = {} self.use_unique_kvtransfer_engine = os.environ.get('LMDEPLOY_USE_UNIQUE_KVTRANSFER_ENGINE', False) + self.recomputation_conn_pool = EngineInstancePool(self.engine) + + self.handle_migration_event = asyncio.Event() + self.handle_meta_migration_event = asyncio.Event() + self.handle_recomputation_event = asyncio.Event() + + self.release_lock = asyncio.Lock() + + self.resp_que: asyncio.Queue[InferOutput] = None + self.has_runable_event: asyncio.Event = None + + def _status_jump(self, msg: SchedulerSequence): + _set_status = lambda msg, status: self.engine.scheduler._set_message_status(msg, status) + _distserve_state_machine = { + MessageStatus.META_MIGRATION_WAITING: MessageStatus.META_MIGRATION_RUNNING, + MessageStatus.META_MIGRATION_RUNNING: MessageStatus.MIGRATION_WAITING, + MessageStatus.MIGRATION_WAITING: MessageStatus.MIGRATION_RUNNING, + MessageStatus.MIGRATION_RUNNING: MessageStatus.MIGRATION_DONE, + MessageStatus.RECOMPUTION_PREEMPTION: MessageStatus.REMOTE_RECOMPUTING, + MessageStatus.REMOTE_RECOMPUTING: MessageStatus.REMOTE_RECOMPUTED + } + if msg.status in _distserve_state_machine: + # TODO (Jimy): handle TO_BE_MIGRATED + _set_status(msg, _distserve_state_machine[msg.status]) async def p2p_initialize(self, init_request: DistServeInitRequest): ctx = zmq.asyncio.Context(2) @@ -49,7 +106,7 @@ async def p2p_initialize(self, init_request: DistServeInitRequest): return DistServeInitResponse(engine_endpoint_info=DistServeEngineEndpointInfo(zmq_address=zmq_address), kvtransfer_endpoint_info=kvtransfer_endpoint_info, - status=DistServeConnectionStatus.SUCCESS) + status=DistServeStatus.SUCCESS) def p2p_connect(self, conn_request: DistServeConnectionRequest): self.p2p_receiver[conn_request.remote_engine_id].connect(conn_request.remote_engine_endpoint_info.zmq_address) @@ -57,30 +114,212 @@ def p2p_connect(self, conn_request: DistServeConnectionRequest): conn_request=conn_request.remote_kvtransfer_endpoint_info) event_loop = asyncio.get_event_loop() event_loop.create_task(self.handle_zmq_recv(conn_request.remote_engine_id)) - return DistServeConnectionResponse(status=DistServeConnectionStatus.SUCCESS) + return DistServeConnectionResponse(status=DistServeStatus.SUCCESS) def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest): # TODO (JimyMa): drop RDMA Connection - self.zmq_disconnect(drop_conn_request.remote_engine_id) + # self.zmq_disconnect(drop_conn_request.remote_engine_id) return {'success': True} - async def zmq_send(self, remote_engine_id: str, remote_session_id: int): - await self.p2p_sender[remote_engine_id].send_pyobj( - DistServeCacheFreeRequest(remote_engine_id=remote_engine_id, remote_session_id=remote_session_id)) + async def init_engine_conn_loop(self, resp_que, has_runable_event): + self.resp_que = resp_que + self.has_runable_event = has_runable_event + event_loop = asyncio.get_event_loop(resp_que, has_runable_event) + loop_tasks = [] + loop_migration = event_loop.create_task( + self.engine_conn._handle_migration(resp_que, has_runable_event=has_runable_event), + name='MainLoopMigration', + ) + loop_meta_migration = event_loop.create_task(self.engine_conn.handle_meta_migration(), name="EngineConnHandleMetaMigration") + loop_recomputation = event_loop.create_task(self.engine_conn.handle_recomputation(), name="HandleRecomputation") + loop_tasks.extend([loop_migration, loop_meta_migration, loop_recomputation]) + 
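# A minimal, standalone sketch of the decode-side status progression that
# `_status_jump` above encodes. `Status` is a stand-in mirroring the
# MessageStatus members added in lmdeploy/pytorch/messages.py later in this
# patch; the sketch is illustrative only, not part of the change itself.
import enum


class Status(enum.Enum):
    META_MIGRATION_WAITING = enum.auto()
    META_MIGRATION_RUNNING = enum.auto()
    MIGRATION_WAITING = enum.auto()
    MIGRATION_RUNNING = enum.auto()
    MIGRATION_DONE = enum.auto()
    RECOMPUTION_PREEMPTION = enum.auto()
    REMOTE_RECOMPUTING = enum.auto()
    REMOTE_RECOMPUTED = enum.auto()


# One transition per call to `_status_jump`, as in the table above.
_NEXT = {
    Status.META_MIGRATION_WAITING: Status.META_MIGRATION_RUNNING,
    Status.META_MIGRATION_RUNNING: Status.MIGRATION_WAITING,
    Status.MIGRATION_WAITING: Status.MIGRATION_RUNNING,
    Status.MIGRATION_RUNNING: Status.MIGRATION_DONE,
    Status.RECOMPUTION_PREEMPTION: Status.REMOTE_RECOMPUTING,
    Status.REMOTE_RECOMPUTING: Status.REMOTE_RECOMPUTED,
}


def step(status: Status) -> Status:
    """Advance one state; states outside the table stay put."""
    return _NEXT.get(status, status)


if __name__ == '__main__':
    s = Status.META_MIGRATION_WAITING
    while s is not Status.MIGRATION_DONE:
        print(s.name)
        s = step(s)
    print(s.name)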
return loop_tasks - async def handle_zmq_recv(self, remote_engine_id: str): + async def handle_meta_migration(self): while True: - req: DistServeCacheFreeRequest = await self.p2p_receiver[remote_engine_id].recv_pyobj() - if isinstance(req, DistServeCacheFreeRequest): - session_id = req.remote_session_id + await self.handle_meta_migration_event.wait() + self.handle_meta_migration_event.clear() + meta_migration_waiting_seqs = self.engine.scheduler.meta_migration_waiting + for meta_migration_waiting_seq in meta_migration_waiting_seqs: + await self.zmq_send( + meta_migration_waiting_seq.migration_context.prefill_engine_id, + DistServeFetchMetaRequest(migration_context=meta_migration_waiting_seq.migration_context) + ) + + async def handle_recomputation(self): + while True: + await self.handle_recomputation_event.wait() + self.handle_recomputation_event.clear() + preempted_seqs = self.engine.scheduler.recomputation_preemption + for preempted_seq in preempted_seqs: + preempted_seq.migration_context.token_ids = preempted_seq.token_ids + await self.zmq_send( + preempted_seq.migration_context.prefill_engine_id, + DistServeRecomputeRequest(migration_context=preempted_seq.migration_context) + ) + + @torch.inference_mode() + async def _handle_migration(self, resp_que: asyncio.Queue, has_runable_event: asyncio.Event): + """Async loop migration.""" + while True: + migration_running = self.engine.scheduler._schedule_migration() + if not migration_running and not self.engine.scheduler.has_migration_waiting(): + await self.handle_migration_event.wait() + elif migration_running: + self.handle_migration_event.clear() + for msg in migration_running: + migration_context = msg.migration_context + migration_context.decode_block_ids = list(self.engine.scheduler.block_manager.get_block_table(msg=msg)) + await self.zmq_send(migration_context.prefill_engine_id, DistServeProactiveMigrationRequest(migration_context=migration_context)) + else: + # release coroutine for decoding + await asyncio.sleep(.5) + + async def zmq_send(self, remote_engine_id: str, req: DistServeEngineConnCallResponse): + _get_msg = lambda session_id: list(self.engine.scheduler.sessions[session_id].sequences.values())[0] + + def _send_preprocess(func: Callable[[], Awaitable[Any]]) -> Callable[[], Awaitable[Any]]: + @functools.wraps(func) + async def wrapper() -> Any: + migration_context = req.migration_context + if self.engine.engine_config.role == EngineRole.Decode: + self._status_jump(_get_msg(migration_context.decode_session_id)) + return await func() + return wrapper + + @_send_preprocess + async def _send_impl(): + logger.error(f"Sending, {req=}") + await self.p2p_sender[remote_engine_id].send_pyobj(req) + + await _send_impl() + + async def handle_zmq_recv(self, remote_engine_id: str): + _get_msg = lambda session_id: list(self.engine.scheduler.sessions[session_id].sequences.values())[0] + + async def _handle_fetch_migration_context_call(req: DistServeFetchMetaRequest): + logger.error("handle fetch migration context call") + migration_context = req.migration_context + msg = _get_msg(migration_context.prefill_session_id) + migration_context.token_ids = msg.all_ids.tolist() + migration_context.prefill_block_ids = list(self.engine.scheduler.block_manager.get_block_table(msg=msg)) + await self.zmq_send( + migration_context.decode_engine_id, + DistServeFetchMetaResponse(migration_context=migration_context, status=DistServeStatus.SUCCESS) + ) + + async def _handle_fetch_migration_context_resp(req: DistServeFetchMetaResponse): + migration_context = 
req.migration_context + msg = _get_msg(migration_context.decode_session_id) + msg.history_cache = HistoryTokenIds(np.array(migration_context.token_ids[:-1])) + msg.migration_context = migration_context + msg.__post_init__() + self._status_jump(msg) + self.handle_migration_event.set() + + async def _handle_remote_preemption_call(req: DistServeRecomputeRequest): + migration_context = req.migration_context + async with self.recomputation_conn_pool.instance() as instance: + gen_config = GenerationConfig( + max_new_tokens=1, + with_cache=True, + preserve_cache=True + ) + if migration_context.prefill_session_id in self.engine.scheduler.sessions: + self.engine.scheduler.end_session(session_id=migration_context.prefill_session_id) + resp = await instance.async_infer(migration_context.prefill_session_id, req.token_ids, gen_config=gen_config) + migration_context.prefill_block_ids = resp.cache_block_ids + migration_context.token_ids = resp.token_ids + recompute_resp = DistServeRecomputeResponse( + migration_context=migration_context, + status=DistServeStatus.SUCCESS + ) + logger.error(f"{self.p2p_sender[migration_context.decode_engine_id]=}") + + await self.zmq_send(migration_context.decode_engine_id, recompute_resp) + + async def _handle_remote_preemption_resp(req: DistServeRecomputeResponse): + migration_context = req.migration_context + msg = _get_msg(migration_context.decode_session_id) + msg.migration_context = migration_context + logger.error(f"{migration_context=}") + self._status_jump(msg) + self.handle_migration_event.set() + + async def _handle_proactive_migration_call(req: DistServeProactiveMigrationRequest): + migration_context = req.migration_context + + def _handle_cache_free(): + session_id = req.migration_context.prefill_session_id if session_id in self.engine.scheduler.sessions: self.engine.scheduler.end_session(session_id=session_id) else: logger.error(f'invalid free, {remote_engine_id}, {session_id}') - else: - raise ValueError(f'Unsupported zmq request {type(req)}') + + migration_execution_requests: List[Tuple[int, List[Tuple[int, int]]]] = [] + logger.error(list(zip(migration_context.prefill_block_ids, migration_context.decode_block_ids))) + migration_execution_requests.append(( + migration_context.decode_engine_id, + list(zip(migration_context.decode_block_ids, migration_context.prefill_block_ids)), + )) + migration_inputs = MigrationExecutionBatch(protocol=migration_context.protocol, + requests=migration_execution_requests) + msg = _get_msg(migration_context.prefill_session_id) + logger.info(f'migrating session: {msg.session_id} begin') + migration_context.time_stamp.migration_begine = time.time() + await self.engine.executor.migrate(migration_inputs) + migration_context.time_stamp.migration_end = time.time() + logger.info(f'migrating session: {msg.session_id} done') + _handle_cache_free() + migration_resp = DistServeProactiveMigrationResponse( + migration_context=migration_context, + status=DistServeStatus.SUCCESS + ) + await self.zmq_send(migration_context.decode_engine_id, migration_resp) + + async def _handle_proactive_migration_resp(req: DistServeProactiveMigrationResponse): + # generate output + migration_context = req.migration_context + outputs: Dict[int, InferOutput] = dict() + msg = _get_msg(migration_context.decode_session_id) + msg.migration_context = migration_context + msg.resp.type = ResponseType.SUCCESS + token_ids = [migration_context.token_ids[-1]] + out = InferOutput( + session_id=migration_context.decode_session_id, + resp=msg.resp, + finish=False, + 
token_ids=np.array(token_ids) + ) + outputs[migration_context.decode_session_id] = out + self.engine.update_running_migration([msg], np.array([token_ids]), [False], [None]) + self.resp_que.put_nowait(outputs) + self._status_jump(msg) + self.has_runable_event.set() + + method_fn: Dict[str, Awaitable[None]] = {} + def _register_method(primitive: DistServeEngineConnCallResponse, fn: Callable[[DistServeConnectionResponse], None]): + method_fn[primitive] = fn + + _register_method(DistServeFetchMetaRequest.__name__, _handle_fetch_migration_context_call) + _register_method(DistServeFetchMetaResponse.__name__, _handle_fetch_migration_context_resp) + _register_method(DistServeRecomputeRequest.__name__, _handle_remote_preemption_call) + _register_method(DistServeRecomputeResponse.__name__, _handle_remote_preemption_resp) + _register_method(DistServeProactiveMigrationRequest.__name__, _handle_proactive_migration_call) + _register_method(DistServeProactiveMigrationResponse.__name__, _handle_proactive_migration_resp) + + while True: + logger.error("starting") + req: DistServeEngineConnCallResponse = await self.p2p_receiver[remote_engine_id].recv_pyobj() + logger.error(f"recv: {req=}, {req.__class__.__name__}") + try: + await method_fn[req.__class__.__name__](req) + except KeyError: + logger.error(f'Unsupported zmq request {type(req)}') + raise KeyError async def zmq_disconnect(self, remote_engine_id: str): - self.p2p_receiver[remote_engine_id].close() - self.p2p_sender[remote_engine_id].close() - self.p2p_conn_ctx[remote_engine_id].term() + async with self.release_lock: + self.p2p_receiver[remote_engine_id].close() + self.p2p_sender[remote_engine_id].close() + self.p2p_conn_ctx[remote_engine_id].term() diff --git a/lmdeploy/pytorch/disagg/conn/protocol.py b/lmdeploy/pytorch/disagg/conn/protocol.py index aa47789497..f8b4c8054d 100644 --- a/lmdeploy/pytorch/disagg/conn/protocol.py +++ b/lmdeploy/pytorch/disagg/conn/protocol.py @@ -8,7 +8,7 @@ DistServeTCPConfig) -class MigrationProtocol(enum.Enum): +class KVTransferProtocol(enum.Enum): """Migration Transport Protocol. 
Attributes: @@ -24,7 +24,7 @@ class MigrationProtocol(enum.Enum): NVLINK = enum.auto() -class DistServeConnectionStatus(enum.Enum): +class DistServeStatus(enum.Enum): # TODO(JimyMa): Add more connection failure handler SUCCESS = enum.auto() FAIL = enum.auto() @@ -37,7 +37,7 @@ class DistServeInitRequest(BaseModel): remote_engine_id: str remote_engine_config: DistServeEngineConfig - protocol: MigrationProtocol + kvtransfer_protocol: KVTransferProtocol rank: Optional[int] = None @@ -51,12 +51,12 @@ class DistServeEngineEndpointInfo(BaseModel): class DistServeKVTransferEndpointInfo(BaseModel): - protocol: MigrationProtocol + protocol: KVTransferProtocol endpoint_info: str class DistServeInitResponse(BaseModel): - status: DistServeConnectionStatus + status: DistServeStatus # the control plane initialization feedback engine_endpoint_info: DistServeEngineEndpointInfo # the KVCache Transfer initialization feedback @@ -67,32 +67,74 @@ class DistServeInitResponse(BaseModel): class DistServeConnectionRequest(BaseModel): - protocol: MigrationProtocol + protocol: KVTransferProtocol remote_engine_id: str remote_engine_endpoint_info: DistServeEngineEndpointInfo remote_kvtransfer_endpoint_info: List[DistServeKVTransferEndpointInfo] +class DistServeDropConnectionRequest(BaseModel): + engine_id: str + remote_engine_id: str + + class DistServeConnectionResponse(BaseModel): - status: DistServeConnectionStatus + status: DistServeStatus -class MigrationRequest(BaseModel): - protocol: MigrationProtocol +class MigrationTimeStamp(BaseModel): + arrive_time: Optional[float] = None + migration_begine: Optional[float] = None + migration_end: Optional[float] = None + + remote_recomputation_begin: Optional[List[float]] = None + remote_recomputation_end: Optional[List[float]] = None - remote_engine_id: str - remote_session_id: int - remote_token_id: int - remote_block_ids: List[int] + +class MigrationContext(BaseModel): + protocol: KVTransferProtocol + + decode_engine_id: str + decode_session_id: Optional[int] + decode_block_ids: Optional[List[int]] + + prefill_engine_id: str + prefill_session_id: int + prefill_block_ids: List[int] + + token_ids: Optional[List[int]] = None + + time_stamp: Optional[MigrationTimeStamp] = None is_dummy_prefill: bool = False +class DistServeFetchMetaRequest(BaseModel): + migration_context: MigrationContext + + +class DistServeFetchMetaResponse(BaseModel): + migration_context: MigrationContext + status: DistServeStatus + + +class DistServeProactiveMigrationRequest(BaseModel): + migration_context: MigrationContext + + +class DistServeProactiveMigrationResponse(BaseModel): + migration_context: MigrationContext + status: DistServeStatus + + class DistServeCacheFreeRequest(BaseModel): - remote_engine_id: str - remote_session_id: int + migration_context: MigrationContext -class DistServeDropConnectionRequest(BaseModel): - engine_id: str - remote_engine_id: str +class DistServeRecomputeRequest(BaseModel): + migration_context: MigrationContext + + +class DistServeRecomputeResponse(BaseModel): + migration_context: MigrationContext + status: DistServeStatus diff --git a/lmdeploy/pytorch/disagg/conn/proxy_conn.py b/lmdeploy/pytorch/disagg/conn/proxy_conn.py index a07d281248..f2691b6864 100644 --- a/lmdeploy/pytorch/disagg/conn/proxy_conn.py +++ b/lmdeploy/pytorch/disagg/conn/proxy_conn.py @@ -162,7 +162,7 @@ async def conn_worker(conn_req: PDConnectionMessage, conn_event: asyncio.Event): # Step 2. 
Construct Initialize Configuration prefill_init_req = DistServeInitRequest( - protocol=conn_req.protocol, + kvtransfer_protocol=conn_req.protocol, local_engine_id=conn_req.p_url, local_engine_config=prefill_engine_config, remote_engine_id=conn_req.d_url, @@ -171,7 +171,7 @@ async def conn_worker(conn_req: PDConnectionMessage, conn_event: asyncio.Event): nvlink_config=conn_req.nvlink_config, ) decode_init_req = DistServeInitRequest( - protocol=conn_req.protocol, + kvtransfer_protocol=conn_req.protocol, local_engine_id=conn_req.d_url, local_engine_config=decode_engine_config, remote_engine_id=conn_req.p_url, @@ -286,7 +286,7 @@ def drop_connect(server_endpoint: str, p2p_disconnect_request: DistServeDropConn logger.warning('cache block gc triggered.') try: for session_id in self.migration_session_shelf[(left, right)]: - cache_free(left, DistServeCacheFreeRequest(remote_engine_id=left, remote_session_id=session_id)) + cache_free(left, DistServeCacheFreeRequest(prefill_engine_id=left, session_id=session_id)) except Exception as e: logger.warning(f'gc error, ErrorMsg: {str(e)}') diff --git a/lmdeploy/pytorch/disagg/messages.py b/lmdeploy/pytorch/disagg/messages.py index 9dac0b0391..94b9462aad 100644 --- a/lmdeploy/pytorch/disagg/messages.py +++ b/lmdeploy/pytorch/disagg/messages.py @@ -4,13 +4,13 @@ from pydantic import BaseModel from lmdeploy.pytorch.disagg.config import DistServeNVLinkConfig, DistServeRDMAConfig, DistServeTCPConfig -from lmdeploy.pytorch.disagg.conn.protocol import MigrationProtocol +from lmdeploy.pytorch.disagg.conn.protocol import KVTransferProtocol class MigrationExecutionBatch(BaseModel): """Input of the Migration.""" - protocol: MigrationProtocol + protocol: KVTransferProtocol requests: List[Tuple[str, List[Tuple[int, int]]]] = [] @@ -24,7 +24,7 @@ class AssignmentInstruct(BaseModel): class MigrationAssignment(BaseModel): """Migration Assignment.""" - protocol: MigrationProtocol + protocol: KVTransferProtocol remote_engine_id: str batch: List[AssignmentInstruct] @@ -32,14 +32,14 @@ class MigrationAssignment(BaseModel): class PDConnectionMessage(BaseModel): p_url: str d_url: str - protocol: MigrationProtocol = MigrationProtocol.RDMA + protocol: KVTransferProtocol = KVTransferProtocol.RDMA tcp_config: Optional[DistServeTCPConfig] = None rdma_config: Optional[DistServeRDMAConfig] = None nvlink_config: Optional[DistServeNVLinkConfig] = None class DistServeRegisterMRMessage(BaseModel): - protocol: MigrationProtocol + protocol: KVTransferProtocol remote_engine_id: str mr_key: str diff --git a/lmdeploy/pytorch/disagg/migration_engine.py b/lmdeploy/pytorch/disagg/migration_engine.py new file mode 100644 index 0000000000..9e3819187d --- /dev/null +++ b/lmdeploy/pytorch/disagg/migration_engine.py @@ -0,0 +1,6 @@ +import torch + + +class MigrationEngine: + def __init__(self): + self.cache_engine \ No newline at end of file diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py index fe61572a20..d194280e33 100644 --- a/lmdeploy/pytorch/engine/cache_engine.py +++ b/lmdeploy/pytorch/engine/cache_engine.py @@ -324,18 +324,18 @@ def p2p_initialize(self, migration_init_request: DistServeInitRequest) -> DistSe for i, t in enumerate(self.full_gpu_cache): if t.numel() == 0: continue - register_mr_request = DistServeRegisterMRMessage(protocol=migration_init_request.protocol, + register_mr_request = DistServeRegisterMRMessage(protocol=migration_init_request.kvtransfer_protocol, remote_engine_id=migration_init_request.remote_engine_id, mr_key=str(i), 
addr=t.data_ptr(), offset=t.storage_offset(), length=t.numel() * t.itemsize) self.migration_backend_impl.register_memory_region(register_mr_request) - return DistServeKVTransferEndpointInfo(protocol=migration_init_request.protocol, + return DistServeKVTransferEndpointInfo(protocol=migration_init_request.kvtransfer_protocol, endpoint_info=json.dumps( self.migration_backend_impl.endpoint_info( migration_init_request.remote_engine_id, - migration_init_request.protocol))) + migration_init_request.kvtransfer_protocol))) def p2p_connect(self, remote_engine_id: str, migration_conn_request: List[DistServeKVTransferEndpointInfo]): self.migration_backend_impl.p2p_connect(remote_engine_id, migration_conn_request[self.tp_rank]) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index da5c5ad719..40e823dd83 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -14,13 +14,12 @@ from lmdeploy.pytorch.disagg.config import EngineRole from lmdeploy.pytorch.disagg.conn.engine_conn import EngineP2PConnection from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest, - DistServeInitRequest) -from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch + DistServeInitRequest, MigrationContext) from lmdeploy.utils import get_logger, get_max_batch_size, get_model, logging_timer from ..adapter.adapter import AdapterManager from ..config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SchedulerConfig -from ..messages import MessageStatus, SchedulerSequence +from ..messages import MessageStatus, SchedulerSequence, InferOutput from ..model_inputs import ModelInputs, VisionModelInputs from ..paging import Scheduler from .engine_checker import EngineChecker @@ -35,25 +34,6 @@ _EMPTY_TOKEN = np.empty((0, ), dtype=np.int64) -@dataclass -class InferOutput: - """The output of the model inference.""" - - session_id: int - resp: Response - token_ids: List[int] - meta: Any = None - finish: bool = False - logits: torch.Tensor = None - - # send cache blocks back for migration in Disaggregated LLM Serving - # when Prefill Engine is Done. 
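# A minimal sketch of the per-tensor memory-region bookkeeping performed by
# CacheEngine.p2p_initialize above for each layer of the GPU KV cache.
# `MRRecord` is a hypothetical stand-in for DistServeRegisterMRMessage,
# reduced to the address fields; it only illustrates how addr/offset/length
# are derived from a torch tensor.
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class MRRecord:
    mr_key: str
    addr: int
    offset: int
    length: int


def collect_mr_records(gpu_cache: List[torch.Tensor]) -> List[MRRecord]:
    records = []
    for i, t in enumerate(gpu_cache):
        if t.numel() == 0:
            # empty tensors are skipped, as in p2p_initialize
            continue
        records.append(
            MRRecord(mr_key=str(i),
                     addr=t.data_ptr(),
                     offset=t.storage_offset(),
                     length=t.numel() * t.itemsize))
    return records


if __name__ == '__main__':
    cache = [torch.zeros(4, 8), torch.empty(0)]
    for rec in collect_mr_records(cache):
        print(rec)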
- cache_block_ids: List[int] = None - - # for logging - req_metrics: RequestMetrics = None - - def _tensorlize_block_offsets(block_offsets, dtype=torch.int32): """Tensorlize block_offsets.""" # copy on numpy is faster than torch.nn.utils.rnn.pad_sequence @@ -235,6 +215,8 @@ def __init__(self, engine: 'Engine'): def do_prefill_dp(self): if self.role == EngineRole.Prefill: return True + elif self.role == EngineRole.Decode: + return False scheduler = self.scheduler @@ -600,23 +582,26 @@ def __update_max_new_tokens(msg): sampling_param = req.data['sampling_param'] return_logits = sampling_param.out_logits if len(sess.sequences) == 0: - migration_request = req.data.get('migration_request') - assert len(req.data['token_ids']) > 0, ('Empty input is not allowed.') + migration_context: MigrationContext = req.data.get('migration_context') + if not migration_context: + assert len(req.data['token_ids']) > 0, ('Empty input is not allowed.') + sess.add_sequence(req.data['token_ids'], sampling_param=sampling_param, adapter_name=req.data['adapter_name'], return_logits=return_logits, multimodals=req.data.get('input_multimodals'), input_embeddings=req.data.get('input_embeddings', ), - migration_request=migration_request, + migration_context=migration_context, resp_cache=req.data.get('with_cache'), preserve_cache=req.data.get('preserve_cache')) msg = next(iter(sess.sequences.values())) __update_max_new_tokens(msg) scheduler.add_sequence(msg) - if migration_request: - self.scheduler._set_message_status(msg, MessageStatus.WAITING_MIGRATION) - self.migration_event.set() + if migration_context: + migration_context.decode_session_id =session_id + self.scheduler._set_message_status(msg, MessageStatus.META_MIGRATION_WAITING) + self.engine_conn.handle_meta_migration_event.set() else: msg = next(iter(sess.sequences.values())) msg.update_token_ids( @@ -801,8 +786,6 @@ def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray, if model_metas is None: model_metas = [None] * len(running) for token, msg, stop, model_meta in zip(next_token_ids, running, stopped, model_metas): - if msg.status != MessageStatus.MIGRATION_LOCKED: - continue update_token = token # fill token @@ -904,7 +887,7 @@ def __need_logits(seqs: SeqList): scheduler = self.scheduler logger.debug(f'Make forward inputs with prefill={prefill}, enable_empty={enable_empty}') - scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prefill_interval) + scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prefill_interval, engine_role=self.engine_config.role) if enable_empty and len(scheduler_output.running) == 0: return None @@ -912,7 +895,7 @@ def __need_logits(seqs: SeqList): # schedule decoding if no valid prefill reqs. 
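# A reduced sketch of how the MigrationContext introduced in
# lmdeploy/pytorch/disagg/conn/protocol.py is filled in over the
# decode/prefill handshake started here (META_MIGRATION_WAITING ->
# DistServeFetchMetaRequest -> ... -> DistServeProactiveMigrationResponse).
# `Ctx` is a hypothetical, trimmed stand-in for MigrationContext; field names
# follow the patch, the values are made up for illustration.
from typing import List, Optional

from pydantic import BaseModel


class Ctx(BaseModel):
    decode_engine_id: str
    decode_session_id: Optional[int] = None
    decode_block_ids: Optional[List[int]] = None
    prefill_engine_id: str
    prefill_session_id: int
    prefill_block_ids: Optional[List[int]] = None
    token_ids: Optional[List[int]] = None


# 1. The decode engine receives a request carrying the prefill-side ids and
#    records its own session id (as done in _on_add_message above).
ctx = Ctx(decode_engine_id='decode:0', prefill_engine_id='prefill:0',
          prefill_session_id=17)
ctx.decode_session_id = 42

# 2. The prefill engine answers DistServeFetchMetaRequest with its token ids
#    and block table (see _handle_fetch_migration_context_call in
#    engine_conn.py).
ctx.token_ids = [1, 2, 3, 4]
ctx.prefill_block_ids = [8, 9]

# 3. The decode engine allocates blocks and issues the proactive migration.
ctx.decode_block_ids = [3, 5]

print(ctx.model_dump())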
if prefill and len(scheduler_output.running) == 0 and self.engine_config.role != EngineRole.Prefill: prefill = False - scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prefill_interval) + scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prefill_interval, engine_role=self.engine_config.role) num_loops = 1 if prefill else prefill_interval running = scheduler_output.running @@ -1007,64 +990,6 @@ def p2p_connect(self, conn_request: DistServeConnectionRequest): async def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest): return self.engine_conn.p2p_drop_connect(drop_conn_request) - @torch.inference_mode() - async def _async_loop_migration(self, resp_que: asyncio.Queue, has_runable_event: asyncio.Event): - """Async loop migration.""" - while True: - migration_running = self.scheduler._schedule_migration() - if not migration_running and not self.scheduler.has_migration_waiting(): - await self.migration_event.wait() - elif migration_running: - self.migration_event.clear() - for msg in migration_running: - migration_execution_requests: List[Tuple[int, List[Tuple[int, int]]]] = [] - migration_request = msg.migration_request - prefill_block_ids = migration_request.remote_block_ids - decode_block_ids = list(self.scheduler.block_manager.get_block_table(msg=msg)) - - if not migration_request.is_dummy_prefill: - assert len(prefill_block_ids) == len(decode_block_ids), ( - f'#prefill block ids ({len(prefill_block_ids)}) must equal to ' - f'#decode block ids ({len(decode_block_ids)})' - f'all id length: {len(msg.num_token_ids)}') - migration_execution_requests.append(( - migration_request.remote_engine_id, - list(zip(prefill_block_ids, decode_block_ids)), - )) - migration_inputs = MigrationExecutionBatch(protocol=migration_request.protocol, - requests=migration_execution_requests) - logger.info(f'migrating session: {msg.session_id} begin') - await self.executor.migrate(migration_inputs) - logger.info(f'migrating session: {msg.session_id} done') - await self.engine_conn.zmq_send(remote_engine_id=migration_request.remote_engine_id, - remote_session_id=migration_request.remote_session_id) - - # generate output - outputs: Dict[int, InferOutput] = dict() - self.scheduler.lock_running_migration(migration_running) - for _, msg in enumerate(migration_running): - session_id = msg.session_id - msg.resp.type = ResponseType.SUCCESS - token_ids = [msg.migration_request.remote_token_id] - # MUST be a wall-clock time - new_token_timestamp = time.time() - req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events) - out = InferOutput( - session_id=session_id, - resp=msg.resp, - finish=False, - token_ids=np.array(token_ids), - metrics_info=req_metrics, - ) - outputs[session_id] = out - self.update_running_migration([msg], np.array([token_ids]), [False], [None]) - resp_que.put_nowait(outputs) - self.scheduler.unlock_running_migration(migration_running) - has_runable_event.set() - else: - # release coroutine for decoding - await asyncio.sleep(.5) - @torch.inference_mode() async def _async_loop_main( self, @@ -1090,6 +1015,8 @@ async def _async_loop_main( scheduler.collect_migration_done() forward_inputs, next_running = await inputs_maker.send_next_inputs() + if self.scheduler.has_recomputation_preempted: + self.engine_conn.handle_recomputation_event.set() if next_running is None: # TODO (JimyMa): add watermark check event instead of async sleep. 
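# A small, self-contained sketch of the event-driven pattern used by
# handle_recomputation / handle_meta_migration in engine_conn.py: the main
# loop flags work by setting an asyncio.Event, and a background task wakes
# up, clears the flag and drains whatever is pending. Names here are
# illustrative, not taken from the engine.
import asyncio


async def recomputation_worker(event: asyncio.Event, pending: list):
    while True:
        await event.wait()
        event.clear()
        while pending:
            session_id = pending.pop()
            print(f'dispatch recompute request for session {session_id}')


async def main_loop(event: asyncio.Event, pending: list):
    for session_id in (1, 2, 3):
        pending.append(session_id)   # a decode sequence got preempted
        event.set()                  # mirrors handle_recomputation_event.set()
        await asyncio.sleep(0)       # let the worker run
    await asyncio.sleep(0.1)


async def main():
    event, pending = asyncio.Event(), []
    worker = asyncio.create_task(recomputation_worker(event, pending))
    await main_loop(event, pending)
    worker.cancel()
    await asyncio.gather(worker, return_exceptions=True)


asyncio.run(main())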
# self.perfill_watermark_event.wait() @@ -1109,7 +1036,7 @@ async def _async_loop_main( for idx in range(num_loops): # pre-forward before get last token - if idx == num_loops - 1: + if idx == num_loops - 1 and self.engine_config.role == EngineRole.Hybrid: scheduler.collect_migration_done() forward_inputs, next_running = await inputs_maker.prefetch_next_inputs() @@ -1191,11 +1118,8 @@ async def async_loop(self): if self.engine_config.role != EngineRole.Hybrid: logger.info('Starting async task MigrationLoop.') - loop_migration = event_loop.create_task( - self._async_loop_migration(resp_que, has_runable_event=has_runable_event), - name='MainLoopMigration', - ) - loop_tasks.append(loop_migration) + engine_conn_tasks = self.engine_conn.init_engine_conn_loop(resp_que, has_runable_event) + loop_tasks.extend(engine_conn_tasks) # binding done callback self._add_loop_tasks_done_callback(loop_tasks) diff --git a/lmdeploy/pytorch/engine/engine_instance.py b/lmdeploy/pytorch/engine/engine_instance.py index f937c59398..ba803ff1f6 100644 --- a/lmdeploy/pytorch/engine/engine_instance.py +++ b/lmdeploy/pytorch/engine/engine_instance.py @@ -72,7 +72,7 @@ def cancel(req_sender: RequestSender, session_id: int): class EngineInstance: - """Instance of TurboMind. + """Instance of PyTorch. Args: engine (Engine): engine @@ -137,7 +137,7 @@ async def async_stream_infer(self, sampling_param=sampling_param, adapter_name=adapter_name, input_multimodals=multimodal, - migration_request=gen_config.migration_request, + migration_context=gen_config.migration_context, with_cache=gen_config.with_cache, preserve_cache=gen_config.preserve_cache, ) diff --git a/lmdeploy/pytorch/engine/mp_engine/engine_instance_pool.py b/lmdeploy/pytorch/engine/mp_engine/engine_instance_pool.py new file mode 100644 index 0000000000..7219db7e48 --- /dev/null +++ b/lmdeploy/pytorch/engine/mp_engine/engine_instance_pool.py @@ -0,0 +1,53 @@ +import asyncio +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from lmdeploy.pytorch.engine.engine import Engine + from lmdeploy.pytorch.engine.engine_instance import EngineInstance + + +class EngineInstancePool: + """Engine Instance Pool.""" + + def __init__(self, engine): + self.engine: "Engine" = engine + self.num_instance = self.engine.engine_config.max_batch_size + self.pool: asyncio.Queue["EngineInstance"] = None + + def create_instance_pool(self, num_instance: int): + """Create instance pool.""" + pool = asyncio.Queue(maxsize=num_instance) + for _ in range(num_instance): + instance = self.engine.create_instance() + pool.put_nowait(instance) + return pool + + @asynccontextmanager + async def instance(self): + """Get an instance from the pool.""" + # lazy create pool + if self.pool is None: + self.pool = self.create_instance_pool(self.num_instance) + instance = await self.pool.get() + try: + yield instance + finally: + self.pool.put_nowait(instance) + + async def async_end(self, session_id: int): + """End the given session.""" + async with self.instance() as instance: + return await instance.async_end(session_id) + + async def async_cancel(self, session_id: int): + """Stop current streaming inference.""" + async with self.instance() as instance: + return await instance.async_cancel(session_id) + + async def async_stream_infer(self, *args, **kwargs): + """Send stream inference request.""" + async with self.instance() as instance: + async for result in instance.async_stream_infer(*args, **kwargs): + yield result \ No newline at end of file diff --git 
a/lmdeploy/pytorch/engine/mp_engine/mp_engine.py b/lmdeploy/pytorch/engine/mp_engine/mp_engine.py index 4100814be8..9460b79541 100644 --- a/lmdeploy/pytorch/engine/mp_engine/mp_engine.py +++ b/lmdeploy/pytorch/engine/mp_engine/mp_engine.py @@ -2,7 +2,7 @@ import asyncio import pickle import signal -from contextlib import asynccontextmanager + from typing import TYPE_CHECKING import torch.multiprocessing as mp @@ -12,6 +12,9 @@ DistServeInitRequest) from lmdeploy.utils import get_logger +from .engine_instance_pool import EngineInstancePool + + logger = get_logger('lmdeploy') if TYPE_CHECKING: @@ -28,52 +31,6 @@ def cancel_async_tasks(loop: asyncio.AbstractEventLoop): loop.close() -class EngineInstancePool: - """Engine Instance Pool.""" - - def __init__(self, engine): - from lmdeploy.pytorch.engine import Engine - self.engine: Engine = engine - self.num_instance = self.engine.engine_config.max_batch_size - self.pool = None - - def create_instance_pool(self, num_instance: int): - """Create instance pool.""" - pool = asyncio.Queue(maxsize=num_instance) - for _ in range(num_instance): - instance = self.engine.create_instance() - pool.put_nowait(instance) - return pool - - @asynccontextmanager - async def instance(self): - """Get an instance from the pool.""" - # lazy create pool - if self.pool is None: - self.pool = self.create_instance_pool(self.num_instance) - instance = await self.pool.get() - try: - yield instance - finally: - self.pool.put_nowait(instance) - - async def async_end(self, session_id: int): - """End the given session.""" - async with self.instance() as instance: - return await instance.async_end(session_id) - - async def async_cancel(self, session_id: int): - """Stop current streaming inference.""" - async with self.instance() as instance: - return await instance.async_cancel(session_id) - - async def async_stream_infer(self, *args, **kwargs): - """Send stream inference request.""" - async with self.instance() as instance: - async for result in instance.async_stream_infer(*args, **kwargs): - yield result - - class MPEngine: def __init__(self, model_path: str, tokenizer: object, engine_config: PytorchEngineConfig = None, **kwargs) -> None: diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 442e7146f4..abe30607bc 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -5,10 +5,13 @@ from typing import Any, Dict, List, Optional import numpy as np + +import torch from torch import Tensor -from lmdeploy.messages import EngineEvent, EventType, GenerationConfig, LogitsProcessor -from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest +from lmdeploy.pytorch.engine.request import Response +from lmdeploy.messages import EngineEvent, EventType, GenerationConfig, LogitsProcessor, RequestMetrics +from lmdeploy.pytorch.disagg.conn.protocol import MigrationContext from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs from lmdeploy.utils import get_logger @@ -35,6 +38,25 @@ def move_position(self, offset: int = 0): return self +@dataclass +class InferOutput: + """The output of the model inference.""" + + session_id: int + resp: Response + token_ids: List[int] + meta: Any = None + finish: bool = False + logits: torch.Tensor = None + + # send cache blocks back for migration in Disaggregated LLM Serving + # when Prefill Engine is Done. 
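# The EngineInstancePool moved into engine_instance_pool.py above is a plain
# asyncio.Queue used as a borrow/return pool behind an async context manager.
# Below is a toy version of the same pattern with a stand-in resource, to show
# how callers such as _handle_remote_preemption_call borrow an instance for
# one call and always hand it back.
import asyncio
from contextlib import asynccontextmanager


class ToyPool:

    def __init__(self, num_instance: int):
        self._queue: asyncio.Queue = asyncio.Queue(maxsize=num_instance)
        for i in range(num_instance):
            self._queue.put_nowait(f'instance-{i}')

    @asynccontextmanager
    async def instance(self):
        inst = await self._queue.get()    # blocks when the pool is exhausted
        try:
            yield inst
        finally:
            self._queue.put_nowait(inst)  # always return, even on error


async def main():
    pool = ToyPool(num_instance=2)

    async def use(tag: str):
        async with pool.instance() as inst:
            print(tag, 'borrowed', inst)
            await asyncio.sleep(0.01)

    await asyncio.gather(*(use(f'call-{i}') for i in range(4)))


asyncio.run(main())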
+ cache_block_ids: List[int] = None + + # for logging + req_metrics: RequestMetrics = None + + @dataclass class SamplingParam: """Sampling parameter.""" @@ -136,17 +158,20 @@ class MessageStatus(enum.Enum): ABORTED = enum.auto() LOCKED = enum.auto() - # PD Disaggregation - # WAITING_MIGRATION: state of Unmigrated Requests - # in both prefill and decode engines are tagged by - # RUNNING_MIGRATION: state of Migrating Requests - # in decode engine + # PD Disaggregation (Prefill Engine) TO_BE_MIGRATED = enum.auto() - WAITING_MIGRATION = enum.auto() - RUNNING_MIGRATION = enum.auto() - MIGRATION_LOCKED = enum.auto() + # PD Disaggregation (Decode Engine) + META_MIGRATION_WAITING = enum.auto() + META_MIGRATION_RUNNING = enum.auto() + + MIGRATION_WAITING = enum.auto() + MIGRATION_RUNNING = enum.auto() MIGRATION_DONE = enum.auto() + RECOMPUTION_PREEMPTION = enum.auto() + REMOTE_RECOMPUTING = enum.auto() + REMOTE_RECOMPUTED = enum.auto() + _SEQ_COUNT = 0 @@ -230,7 +255,7 @@ def add_sequence(self, return_logits: bool = False, multimodals: MultiModalInputs = None, input_embeddings: List[InputEmbeddings] = None, - migration_request: Optional[MigrationRequest] = None, + migration_context: Optional[MigrationContext] = None, resp_cache: bool = False, preserve_cache: bool = False) -> 'SchedulerSequence': """Add a new message.""" @@ -250,11 +275,11 @@ def add_sequence(self, num_new_tokens=0, sampling_param=sampling_param, adapter_name=adapter_name, - arrive_time=time.perf_counter(), + arrive_time=migration_context.time_stamp.arrive_time if migration_context else time.time(), history_embeddings=HistoryEmbeddings(input_embeddings), history_multimodals=HistoryMultiModals(multimodals), return_logits=return_logits, - migration_request=migration_request, + migration_context=migration_context, resp_cache=resp_cache, preserve_cache=preserve_cache, ) @@ -464,7 +489,7 @@ class SchedulerSequence: model_meta: Dict[str, Any] = None # For Disaggregation - migration_request: Optional[MigrationRequest] = None + migration_context: Optional[MigrationContext] = None resp_cache: bool = False preserve_cache: bool = False diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 0be35ab9be..4483122140 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -5,6 +5,8 @@ from dataclasses import dataclass from typing import Dict, List +from lmdeploy.pytorch.disagg.config import EngineRole + from lmdeploy.messages import EventType, ScheduleMetrics from lmdeploy.utils import get_logger, logging_timer @@ -79,13 +81,13 @@ def locked(self): @property def waiting_migration(self): """Get migration sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.WAITING_MIGRATION) + seq_map = self.seq_manager.get_sequences(MessageStatus.MIGRATION_WAITING) return list(seq_map.values()) @property def running_migration(self): """Get migration sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.RUNNING_MIGRATION) + seq_map = self.seq_manager.get_sequences(MessageStatus.MIGRATION_RUNNING) return list(seq_map.values()) @property @@ -94,6 +96,31 @@ def migration_done(self): seq_map = self.seq_manager.get_sequences(MessageStatus.MIGRATION_DONE) return list(seq_map.values()) + @property + def meta_migration_waiting(self): + seq_map = self.seq_manager.get_sequences(MessageStatus.META_MIGRATION_WAITING) + return list(seq_map.values()) + + @property + def meta_migration_running(self): + seq_map = 
self.seq_manager.get_sequences(MessageStatus.META_MIGRATION_RUNNING) + return list(seq_map.values()) + + @property + def recomputation_preemption(self): + seq_map = self.seq_manager.get_sequences(MessageStatus.RECOMPUTION_PREEMPTION) + return list(seq_map.values()) + + @property + def remote_recomputing(self): + seq_map = self.seq_manager.get_sequences(MessageStatus.REMOTE_RECOMPUTING) + return list(seq_map.values()) + + @property + def remote_recomputed(self): + seq_map = self.seq_manager.get_sequences(MessageStatus.REMOTE_RECOMPUTED) + return list(seq_map.values()) + def build_eviction_helper(self, eviction_type: str): if eviction_type == 'copy': logger.warning('`copy` eviction has been deprecated, ' @@ -145,7 +172,7 @@ def _schedule_migration(self): def _to_running(seq: SchedulerSequence): """To running.""" - seq.status = MessageStatus.RUNNING_MIGRATION + seq.status = MessageStatus.MIGRATION_RUNNING running_migration.append(seq) nonlocal migrating_token_count migrating_token_count += seq.num_token_ids @@ -234,7 +261,7 @@ def _reorder_waiting(): return running, swap_in_map, swap_out_map, copy_map @logging_timer('ScheduleDecoding', logger) - def _schedule_decoding(self, prealloc_size: int = 0): + def _schedule_decoding(self, prealloc_size: int = 0, engine_role:EngineRole = EngineRole.Hybrid): """Schedule decoding.""" running = self.running @@ -257,7 +284,9 @@ def __evict_for_seq(seq: SchedulerSequence, num_required_blocks: int): from itertools import chain hanging = reversed(self.hanging) waiting = reversed(self.waiting) - evictable = list(chain(hanging, waiting)) + recompution_preemption = reversed(self.recomputation_preemption) + remote_recomputing = reversed(self.remote_recomputing) + evictable = list(chain(hanging, waiting, recompution_preemption, remote_recomputing)) return eviction_helper.evict_for_seq(seq, evictable, prealloc_size) # 1. 
running @@ -276,7 +305,10 @@ def __evict_for_seq(seq: SchedulerSequence, num_required_blocks: int): continue if not __evict_for_seq(seq, num_required_blocks): - self._set_message_status(seq, MessageStatus.WAITING) + if engine_role == EngineRole.Decode: + self._set_message_status(seq, MessageStatus.RECOMPUTION_PREEMPTION) + else: + self._set_message_status(seq, MessageStatus.WAITING) continue self.block_manager.allocate(seq, prealloc_size) @@ -284,12 +316,12 @@ def __evict_for_seq(seq: SchedulerSequence, num_required_blocks: int): return self.running, swap_in_map, swap_out_map, copy_map - def schedule(self, is_prefill: bool, prealloc_size: int = 0): + def schedule(self, is_prefill: bool, prealloc_size: int = 0, engine_role: EngineRole = EngineRole.Hybrid): """Schedule inputs for next steps.""" if is_prefill: output = self._schedule_prefill() else: - output = self._schedule_decoding(prealloc_size) + output = self._schedule_decoding(prealloc_size, engine_role) running, swap_in_map, swap_out_map, copy_map = output return SchedulerOutput(running=running, swap_in_map=swap_in_map, swap_out_map=swap_out_map, copy_map=copy_map) @@ -359,6 +391,15 @@ def has_migration_waiting(self): def has_migration_done(self): return self.num_migration_done() > 0 + def has_recomputation_preempted(self): + return self.num_recomputation_preemption() > 0 + + def has_remote_recomputing(self): + return self.num_remote_recomputing() > 0 + + def has_remote_recomputed(self): + return self.num_remote_recomputed() > 0 + def get_block_tables(self, seqs: SeqList): """Get block table of the sequences.""" return [self.block_manager.get_block_table(seq) for seq in seqs] @@ -375,21 +416,26 @@ def num_to_be_migrated(self): """Num waiting.""" return self.seq_manager.num_sequences(MessageStatus.TO_BE_MIGRATED) - def num_migration_locked(self): - """Num waiting.""" - return self.seq_manager.num_sequences(MessageStatus.MIGRATION_LOCKED) - def num_migration_running(self): """Num migration running.""" - return self.seq_manager.num_sequences(MessageStatus.RUNNING_MIGRATION) + return self.seq_manager.num_sequences(MessageStatus.MIGRATION_RUNNING) def num_migration_done(self): """Num migration done.""" return self.seq_manager.num_sequences(MessageStatus.MIGRATION_DONE) + def num_recomputation_preemption(self): + return self.seq_manager.num_sequences(MessageStatus.RECOMPUTION_PREEMPTION) + + def num_remote_recomputing(self): + return self.seq_manager.num_sequences(MessageStatus.REMOTE_RECOMPUTING) + + def num_remote_recomputed(self): + return self.seq_manager.num_sequences(MessageStatus.REMOTE_RECOMPUTED) + def num_migration_waiting(self): """Num waiting.""" - return self.seq_manager.num_sequences(MessageStatus.WAITING_MIGRATION) + return self.seq_manager.num_sequences(MessageStatus.MIGRATION_WAITING) def num_locked(self): """Num locked.""" @@ -406,18 +452,6 @@ def unlock_running(self, locked: SeqList): if seq.status == MessageStatus.LOCKED: self._set_message_status(seq, MessageStatus.RUNNING) - def lock_running_migration(self, running: SeqList): - """Lock running sequence.""" - for seq in running: - if seq.status == MessageStatus.RUNNING_MIGRATION: - self._set_message_status(seq, MessageStatus.MIGRATION_LOCKED) - - def unlock_running_migration(self, locked: SeqList): - """Unlock running migration.""" - for seq in locked: - if seq.status == MessageStatus.MIGRATION_LOCKED: - self._set_message_status(seq, MessageStatus.MIGRATION_DONE) - def collect_migration_done(self): migration_done = self.migration_done for seq in migration_done: diff --git 
a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index 0de4e3c49d..2b0c3a8233 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -725,6 +725,10 @@ async def generate( f'max_new_tokens={gen_config.max_new_tokens}, ' f'seq_start={sequence_start}, seq_end={sequence_end}, ' f'step={step}, prep={do_preprocess}') + elif gen_config.migration_context: + # In PD Disaggregation mode, initialize token lazily + input_ids = [] + prompt_input = dict(input_ids=input_ids) else: # TODO(lvhan) VLM doesn't support input_ids as an argument. # Figure out a graceful way to handle the invalid input diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 5d9f9b7c86..fc95e2a032 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -27,8 +27,7 @@ from lmdeploy.model import ChatTemplateConfig from lmdeploy.pytorch.disagg.config import DistServeEngineConfig from lmdeploy.pytorch.disagg.conn.protocol import (DistServeCacheFreeRequest, DistServeConnectionRequest, - DistServeDropConnectionRequest, DistServeInitRequest, - MigrationRequest) + DistServeDropConnectionRequest, DistServeInitRequest) from lmdeploy.serve.async_engine import AsyncEngine from lmdeploy.serve.openai.protocol import ChatCompletionResponse # noqa: E501 from lmdeploy.serve.openai.protocol import (ChatCompletionRequest, ChatCompletionResponseChoice, @@ -348,13 +347,6 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque - presence_penalty (replaced with repetition_penalty) - frequency_penalty (replaced with repetition_penalty) """ - json_request = await raw_request.json() - migration_request = json_request.pop('migration_request', None) - with_cache = json_request.pop('with_cache', False) - preserve_cache = json_request.pop('preserve_cache', False) - if migration_request: - migration_request = MigrationRequest.model_validate(migration_request) - if request.session_id == -1: VariableInterface.session_id += 1 request.session_id = VariableInterface.session_id @@ -409,9 +401,9 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque min_p=request.min_p, random_seed=random_seed, spaces_between_special_tokens=request.spaces_between_special_tokens, - migration_request=migration_request, - with_cache=with_cache, - preserve_cache=preserve_cache) + migration_context=request.migration_context, + with_cache=request.with_cache, + preserve_cache=request.preserve_cache) tools = None if request.tools and request.tool_choice != 'none': @@ -587,7 +579,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: ) choices.append(choice_data) - if with_cache: + if request.with_cache: cache_block_ids = cache_block_ids[0] remote_token_ids = [remote_token_ids[0][-1]] @@ -605,7 +597,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: usage=usage, ).model_dump() - if with_cache: + if request.with_cache: response['cache_block_ids'] = cache_block_ids response['remote_token_ids'] = remote_token_ids @@ -659,13 +651,6 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None - presence_penalty (replaced with repetition_penalty) - frequency_penalty (replaced with repetition_penalty) """ - json_request = await raw_request.json() - migration_request = json_request.pop('migration_request', None) - with_cache = json_request.pop('with_cache', False) - preserve_cache = json_request.pop('preserve_cache', False) - if migration_request: - 
diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index 5d9f9b7c86..fc95e2a032 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -27,8 +27,7 @@
 from lmdeploy.model import ChatTemplateConfig
 from lmdeploy.pytorch.disagg.config import DistServeEngineConfig
 from lmdeploy.pytorch.disagg.conn.protocol import (DistServeCacheFreeRequest, DistServeConnectionRequest,
-                                                   DistServeDropConnectionRequest, DistServeInitRequest,
-                                                   MigrationRequest)
+                                                   DistServeDropConnectionRequest, DistServeInitRequest)
 from lmdeploy.serve.async_engine import AsyncEngine
 from lmdeploy.serve.openai.protocol import ChatCompletionResponse  # noqa: E501
 from lmdeploy.serve.openai.protocol import (ChatCompletionRequest, ChatCompletionResponseChoice,
@@ -348,13 +347,6 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque
     - presence_penalty (replaced with repetition_penalty)
     - frequency_penalty (replaced with repetition_penalty)
     """
-    json_request = await raw_request.json()
-    migration_request = json_request.pop('migration_request', None)
-    with_cache = json_request.pop('with_cache', False)
-    preserve_cache = json_request.pop('preserve_cache', False)
-    if migration_request:
-        migration_request = MigrationRequest.model_validate(migration_request)
-
     if request.session_id == -1:
         VariableInterface.session_id += 1
         request.session_id = VariableInterface.session_id
@@ -409,9 +401,9 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque
         min_p=request.min_p,
         random_seed=random_seed,
         spaces_between_special_tokens=request.spaces_between_special_tokens,
-        migration_request=migration_request,
-        with_cache=with_cache,
-        preserve_cache=preserve_cache)
+        migration_context=request.migration_context,
+        with_cache=request.with_cache,
+        preserve_cache=request.preserve_cache)
 
     tools = None
     if request.tools and request.tool_choice != 'none':
@@ -587,7 +579,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
                 )
                 choices.append(choice_data)
 
-            if with_cache:
+            if request.with_cache:
                 cache_block_ids = cache_block_ids[0]
                 remote_token_ids = [remote_token_ids[0][-1]]
 
@@ -605,7 +597,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
                 usage=usage,
             ).model_dump()
 
-            if with_cache:
+            if request.with_cache:
                 response['cache_block_ids'] = cache_block_ids
                 response['remote_token_ids'] = remote_token_ids
 
@@ -659,13 +651,6 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None
     - presence_penalty (replaced with repetition_penalty)
     - frequency_penalty (replaced with repetition_penalty)
     """
-    json_request = await raw_request.json()
-    migration_request = json_request.pop('migration_request', None)
-    with_cache = json_request.pop('with_cache', False)
-    preserve_cache = json_request.pop('preserve_cache', False)
-    if migration_request:
-        migration_request = MigrationRequest.model_validate(migration_request)
-
     if request.session_id == -1:
         VariableInterface.session_id += 1
         request.session_id = VariableInterface.session_id
@@ -700,9 +685,9 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None
         min_p=request.min_p,
         random_seed=random_seed,
         spaces_between_special_tokens=request.spaces_between_special_tokens,
-        migration_request=migration_request,
-        with_cache=with_cache,
-        preserve_cache=preserve_cache)
+        migration_context=request.migration_context,
+        with_cache=request.with_cache,
+        preserve_cache=request.preserve_cache)
     generators = []
     for i in range(len(request.prompt)):
         result_generator = VariableInterface.async_engine.generate(
@@ -817,7 +802,7 @@ async def _inner_call(i, generator):
                 )
                 choices[i] = choice_data
 
-            if with_cache:
+            if request.with_cache:
                 cache_block_ids = cache_block_ids[0]
                 remote_token_ids = [remote_token_ids[0][-1]]
 
@@ -836,7 +821,7 @@ async def _inner_call(i, generator):
                 usage=usage,
             ).model_dump()
 
-            if with_cache:
+            if request.with_cache:
                 response['cache_block_ids'] = cache_block_ids
                 response['remote_token_ids'] = remote_token_ids
 
@@ -941,6 +926,17 @@ def update_params(request: UpdateParamsRequest, raw_request: Request = None):
 """ PD Disaggregation API Begin """
 
 
+@router.get('/distserve/cache_info')
+async def cache_info():
+    # TODO (Jimy): disaggregate Migration and Forward Process
+    raise NotImplementedError
+
+@router.get('/distserve/model_info')
+async def model_info():
+    # TODO (Jimy): disaggregate Migration and Forward Process
+    raise NotImplementedError
+
+
 @router.get('/distserve/engine_info')
 async def engine_info():
     engine_config = VariableInterface.async_engine.backend_config
@@ -957,6 +953,18 @@ async def engine_info():
     return response.model_dump_json()
 
 
+@router.get('/distserve/deref_local_gpu_cache')
+async def deref_local_gpu_cache():
+    # TODO (Jimy): disaggregate Migration and Forward Process
+    raise NotImplementedError
+
+
+@router.get('/distserve/ref_remote_gpu_cache')
+async def ref_remote_gpu_cache():
+    # TODO (Jimy): disaggregate Migration and Forward Process
+    raise NotImplementedError
+
+
 @router.post('/distserve/p2p_initialize')
 async def p2p_initialize(init_request: DistServeInitRequest):
     return VariableInterface.async_engine.p2p_initialize(init_request)
@@ -974,7 +982,7 @@ async def p2p_drop_connect(drop_conn_request: DistServeDropConnectionRequest):
 
 
 @router.post('/distserve/free_cache')
 async def free_cache(cache_free_request: DistServeCacheFreeRequest) -> JSONResponse:
-    session_id = cache_free_request.remote_session_id
+    session_id = cache_free_request.session_id
     VariableInterface.async_engine.free_cache(session_id)
     return {'status': 'SUCCESS'}
diff --git a/lmdeploy/serve/openai/protocol.py b/lmdeploy/serve/openai/protocol.py
index b29ec09501..64e07addc3 100644
--- a/lmdeploy/serve/openai/protocol.py
+++ b/lmdeploy/serve/openai/protocol.py
@@ -7,6 +7,8 @@
 import shortuuid
 from pydantic import BaseModel, Field
 
+from lmdeploy.pytorch.disagg.conn.protocol import MigrationContext
+
 
 class ErrorResponse(BaseModel):
     """Error responses."""
@@ -139,6 +141,11 @@ class ChatCompletionRequest(BaseModel):
     min_p: float = 0.0
     enable_thinking: Optional[bool] = None
 
+    # For DistServe
+    with_cache: bool = False
+    preserve_cache: bool = False
+    migration_context: Optional[MigrationContext] = None
+
 
 class FunctionCall(BaseModel):
     """Function response."""
@@ -280,6 +287,11 @@ class CompletionRequest(BaseModel):
     seed: Optional[int] = None
     min_p: float = 0.0
 
+    # For DistServe
+    with_cache: bool = False
+    preserve_cache: bool = False
+    migration_context: Optional[MigrationContext] = None
+
 
 class CompletionResponseChoice(BaseModel):
     """Completion response choices."""
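
With the DistServe fields now part of the request schema, the server no longer has to re-parse the raw JSON body for them. An illustrative client-side payload for the prefill step (host, port, and model name are placeholders):

import requests  # assumed available in the client environment

prefill_payload = {
    'model': 'internlm2',          # placeholder model name
    'messages': [{'role': 'user', 'content': 'Hello'}],
    'max_tokens': 1,               # prefill-only step
    'stream': False,
    'with_cache': True,            # ask the server to return cache_block_ids / remote_token_ids
    'preserve_cache': True,        # keep the KV cache alive until migration finishes
}
resp = requests.post('http://prefill-node:23333/v1/chat/completions', json=prefill_payload)
prefill_info = resp.json()         # contains 'cache_block_ids' / 'remote_token_ids' when with_cache is set
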
diff --git a/lmdeploy/serve/proxy/constants.py b/lmdeploy/serve/proxy/constants.py
index a62e8fc99b..05234097ee 100644
--- a/lmdeploy/serve/proxy/constants.py
+++ b/lmdeploy/serve/proxy/constants.py
@@ -18,6 +18,7 @@ class RoutingStrategy(enum.Enum):
     """Strategy to dispatch requests to nodes."""
 
     RANDOM = enum.auto()
+    ROUND_ROBIN = enum.auto()
     MIN_EXPECTED_LATENCY = enum.auto()
     MIN_OBSERVED_LATENCY = enum.auto()
 
@@ -26,6 +27,8 @@ def from_str(cls, name):
         """Get strategy from string."""
         if name == 'random':
             return cls.RANDOM
+        elif name == 'round_robin':
+            return cls.ROUND_ROBIN
         elif name == 'min_expected_latency':
             return cls.MIN_EXPECTED_LATENCY
         elif name == 'min_observed_latency':
diff --git a/lmdeploy/serve/proxy/proxy.py b/lmdeploy/serve/proxy/proxy.py
index b5977a4724..5e7af562ce 100644
--- a/lmdeploy/serve/proxy/proxy.py
+++ b/lmdeploy/serve/proxy/proxy.py
@@ -8,7 +8,7 @@
 import random
 import threading
 import time
-from collections import deque
+from collections import defaultdict, deque
 from http import HTTPStatus
 from typing import Deque, Dict, List, Literal, Optional, Union
 
@@ -22,7 +22,7 @@ from pydantic import BaseModel, Field
 
 from lmdeploy.pytorch.disagg.config import DistServeRDMAConfig, EngineRole, RDMALinkType, ServingStrategy
-from lmdeploy.pytorch.disagg.conn.protocol import MigrationProtocol, MigrationRequest
+from lmdeploy.pytorch.disagg.conn.protocol import KVTransferProtocol, MigrationContext, MigrationTimeStamp
 from lmdeploy.pytorch.disagg.conn.proxy_conn import PDConnectionPool
 from lmdeploy.pytorch.disagg.messages import PDConnectionMessage
 from lmdeploy.serve.openai.api_server import check_api_key, create_error_response
@@ -31,6 +31,7 @@
 from lmdeploy.serve.proxy.constants import AIOHTTP_TIMEOUT, LATENCY_DEQUE_LEN, ErrorCodes, RoutingStrategy, err_msg
 from lmdeploy.utils import get_logger
 
+
 logger = get_logger('lmdeploy')
 
 
@@ -105,7 +106,7 @@ def __init__(self,
         self.aiotimeout = aiohttp.ClientTimeout(total=AIOHTTP_TIMEOUT)
 
         # For PD Disaggregation
-        self.migration_protocol = MigrationProtocol[migration_protocol]
+        self.migration_protocol = KVTransferProtocol[migration_protocol]
         self.rdma_config = DistServeRDMAConfig(with_gdr=with_gdr, link_type=RDMALinkType[link_type])
         self.pd_connection_pool = PDConnectionPool()
         self.dummy_prefill = False
@@ -276,6 +277,18 @@ def get_matched_urls():
             index = random.choices(range(len(all_matched_urls)), weights=weights)[0]
             url = all_matched_urls[index]
             return url
+
+        elif self.routing_strategy == RoutingStrategy.ROUND_ROBIN:
+            # one monotonically increasing counter per engine role
+            if not hasattr(NodeManager.get_node_url, 'rr_counter'):
+                NodeManager.get_node_url.rr_counter = defaultdict(int)
+            urls = list(self.get_nodes(role).keys())
+            if not urls:
+                return None
+            index = NodeManager.get_node_url.rr_counter[role] % len(urls)
+            NodeManager.get_node_url.rr_counter[role] += 1
+            return urls[index]
+
         elif self.routing_strategy == RoutingStrategy.MIN_EXPECTED_LATENCY:
             all_matched_urls, all_the_speeds = get_matched_urls()
             if len(all_matched_urls) == 0:
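
The ROUND_ROBIN branch above keeps one counter per engine role on the get_node_url function object and cycles through the registered node URLs. The same idea in a small standalone form (names here are illustrative, not part of lmdeploy's API):

from collections import defaultdict
from itertools import count

class RoundRobinPicker:
    """Cycle through a list of node URLs, independently for each engine role."""

    def __init__(self):
        # defaultdict(count) lazily creates a fresh 0, 1, 2, ... counter per role
        self._counters = defaultdict(count)

    def pick(self, role, urls):
        if not urls:
            return None
        return urls[next(self._counters[role]) % len(urls)]

picker = RoundRobinPicker()
urls = ['http://10.0.0.1:23333', 'http://10.0.0.2:23333']
assert picker.pick('decode', urls) == urls[0]
assert picker.pick('decode', urls) == urls[1]
assert picker.pick('decode', urls) == urls[0]
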
@@ -590,16 +603,16 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque
         node_manager.post_call(node_url, start)
         return JSONResponse(json.loads(response))
     elif node_manager.serving_strategy == ServingStrategy.DistServe:
+        time_stamp = MigrationTimeStamp(arrive_time=time.time())
         request_dict = request.model_dump()
 
         # Prefill
         prefill_request_dict = copy.deepcopy(request_dict)
         prefill_request_dict['max_tokens'] = 1
         prefill_request_dict['stream'] = False
-        prefill_request_dict['with_cache'] = True
         prefill_request_dict['preserve_cache'] = True
 
-        prefill_info = {}
+        prefill_info = {'id': -1}
         p_url = 'dummy:dummy'
         if not node_manager.dummy_prefill:
             p_url = node_manager.get_node_url(request.model, EngineRole.Prefill)
@@ -627,17 +640,21 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque
                     rdma_config=node_manager.rdma_config,
                 ))
 
-        remote_session_id = int(prefill_info.get('id')) if prefill_info.get('id') else 0
-        remote_block_ids = prefill_info.get('cache_block_ids') or []
-        remote_token_id = prefill_info.get('remote_token_ids')[-1] if prefill_info.get('remote_token_ids') else 0
-
-        request_dict['migration_request'] = MigrationRequest(
+        # Mute the messages and transfer them lazily from the prefill engine.
+        # TODO (JimyMa): Support batched requests
+        # Dummy messages: the prompt tokens can be migrated from the prefill engine
+        request_dict['messages'] = []
+        request_dict['migration_context'] = MigrationContext(
             protocol=node_manager.migration_protocol,
-            remote_engine_id=p_url,
-            remote_session_id=remote_session_id,
-            remote_block_ids=remote_block_ids,
-            remote_token_id=remote_token_id,
-            is_dummy_prefill=node_manager.dummy_prefill).model_dump(mode='json')
+            decode_engine_id=d_url,
+            decode_session_id=None,
+            decode_block_ids=[],
+            prefill_engine_id=p_url,
+            prefill_session_id=prefill_info.get('id', -1),
+            prefill_block_ids=[],
+            token_ids=[],
+            is_dummy_prefill=node_manager.dummy_prefill,
+            time_stamp=time_stamp).model_dump(mode='json')
 
         start = node_manager.pre_call(d_url)
         if not node_manager.dummy_prefill:
@@ -717,6 +734,7 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None
         node_manager.post_call(node_url, start)
         return JSONResponse(json.loads(response))
     elif node_manager.serving_strategy == ServingStrategy.DistServe:
+        time_stamp = MigrationTimeStamp(arrive_time=time.time())
         request_dict = request.model_dump()
 
         # Prefill
@@ -772,14 +790,23 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None
 
         remote_session_id = int(prefill_info.get('id')) if prefill_info.get('id') else 0
         remote_block_ids = prefill_info.get('cache_block_ids') or []
-        remote_token_id = prefill_info.get('remote_token_ids')[-1] if prefill_info.get('remote_token_ids') else 0
-        request_dict['migration_request'] = MigrationRequest(
+
+        # TODO (JimyMa): Support batched requests
+        # Dummy prompt: the prompt tokens can be migrated from the prefill engine
+        request_dict['prompt'] = ''
+        logger.debug(f'prefill_info: {prefill_info}')
+
+        request_dict['migration_context'] = MigrationContext(
             protocol=node_manager.migration_protocol,
-            remote_engine_id=p_url,
-            remote_session_id=remote_session_id,
-            remote_block_ids=remote_block_ids,
-            remote_token_id=remote_token_id,
-            is_dummy_prefill=node_manager.dummy_prefill).model_dump(mode='json')
+            decode_engine_id=d_url,
+            decode_session_id=None,
+            decode_block_ids=[],
+            prefill_engine_id=p_url,
+            prefill_session_id=remote_session_id,
+            prefill_block_ids=remote_block_ids,
+            token_ids=[],
+            is_dummy_prefill=node_manager.dummy_prefill,
+            time_stamp=time_stamp).model_dump(mode='json')
 
         start = node_manager.pre_call(d_url)
         if not node_manager.dummy_prefill:
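
For reference, a hedged sketch of the decode-side payload assembled above; the URLs, session id, and transport choice are illustrative placeholders, while the field names follow the MigrationContext construction in this patch:

import time

from lmdeploy.pytorch.disagg.conn.protocol import (KVTransferProtocol, MigrationContext,
                                                   MigrationTimeStamp)

request_dict = {'prompt': '', 'max_tokens': 64}          # decode-side request body being assembled
migration_context = MigrationContext(
    protocol=KVTransferProtocol.RDMA,                    # assumed transport; RDMA is the proxy default
    decode_engine_id='http://10.0.0.2:23333',            # d_url (illustrative)
    decode_session_id=None,                              # filled in by the decode engine
    decode_block_ids=[],
    prefill_engine_id='http://10.0.0.1:23333',           # p_url (illustrative)
    prefill_session_id=17,                               # prefill_info['id'] from the prefill response
    prefill_block_ids=[],                                # may be filled from the prefill response's cache_block_ids
    token_ids=[],                                        # prompt tokens are migrated lazily
    is_dummy_prefill=False,
    time_stamp=MigrationTimeStamp(arrive_time=time.time()),
)
request_dict['migration_context'] = migration_context.model_dump(mode='json')
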
@@ -802,7 +829,7 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None
 def proxy(server_name: str = '0.0.0.0',
           server_port: int = 8000,
           serving_strategy: Literal['Hybrid', 'DistServe'] = 'Hybrid',
-          routing_strategy: Literal['random', 'min_expected_latency', 'min_observed_latency'] = 'min_expected_latency',
+          routing_strategy: Literal['random', 'round_robin', 'min_expected_latency', 'min_observed_latency'] = 'min_expected_latency',
           api_keys: Optional[Union[List[str], str]] = None,
           ssl: bool = False,
           log_level: str = 'INFO',
@@ -818,7 +845,7 @@ def proxy(server_name: str = '0.0.0.0',
         server_port (str): the server port. Default to 8000.
         serving_strategy ('Hybrid' | 'DistServe'): the strategy to serving. Hybrid default. DistServe
             for PD Disaggregation.
-        route_strategy ('random' | 'min_expected_latency' | 'min_observed_latency'):
+        routing_strategy ('random' | 'round_robin' | 'min_expected_latency' | 'min_observed_latency'):
             the strategy to dispatch requests to nodes. Default to 'min_expected_latency'
         api_keys (List[str] | str | None): Optional list of API keys. Accepts string type as
@@ -831,7 +858,7 @@ def proxy(server_name: str = '0.0.0.0',
     """  # noqa
     node_manager.serving_strategy = ServingStrategy[serving_strategy]
     node_manager.routing_strategy = RoutingStrategy.from_str(routing_strategy)
-    node_manager.migration_protocol = MigrationProtocol[migration_protocol]
+    node_manager.migration_protocol = KVTransferProtocol[migration_protocol]
    node_manager.dummy_prefill = dummy_prefill
    node_manager.rdma_config = DistServeRDMAConfig(