diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 74f268b5b1..80003c9ded 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -232,6 +232,7 @@ def add_parser_proxy(): default='Hybrid', help='the strategy to serve, Hybrid for colocating Prefill and Decode' 'workloads into same engine, DistServe for Prefill-Decode Disaggregation') + parser.add_argument('--dummy-prefill', action='store_true', help='dummy prefill for performance profiler') parser.add_argument('--routing-strategy', type=str, choices=['random', 'min_expected_latency', 'min_observed_latency'], diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index af7c091770..cb2dca6b68 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -8,7 +8,7 @@ from pydantic.dataclasses import dataclass as pydantic_dataclass from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend -from lmdeploy.pytorch.disagg.request import MigrationRequest +from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest from .tokenizer import Tokenizer from .utils import get_logger diff --git a/lmdeploy/pytorch/disagg/backend/__init__.py b/lmdeploy/pytorch/disagg/backend/__init__.py index 1b45840333..3ab02d3bd6 100644 --- a/lmdeploy/pytorch/disagg/backend/__init__.py +++ b/lmdeploy/pytorch/disagg/backend/__init__.py @@ -15,10 +15,4 @@ except ImportError: logger.warning('Disable Mooncake Backend') -try: - logger.debug('Registering InfiniStoreBackend Backend') - from .infinistore import InfiniStoreBackend -except ImportError: - logger.warning('Disable InfiniStoreBackend Backend') - -__all__ = ['DLSlimeBackend', 'MooncakeBackend', 'InfiniStoreBackend'] +__all__ = ['DLSlimeBackend', 'MooncakeBackend'] diff --git a/lmdeploy/pytorch/disagg/backend/base.py b/lmdeploy/pytorch/disagg/backend/base.py index 200443d127..7e7716dffc 100644 --- a/lmdeploy/pytorch/disagg/backend/base.py +++ b/lmdeploy/pytorch/disagg/backend/base.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from abc import abstractmethod -from lmdeploy.pytorch.disagg.config import MigrationProtocol +from lmdeploy.pytorch.disagg.conn.protocol import (DistServeInitRequest, DistServeKVTransferEndpointInfo, + MigrationProtocol) from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment -from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest class MigrationBackendImpl: @@ -21,7 +21,7 @@ def endpoint_info(self, remote_engine_id: int, protocol: MigrationProtocol): return NotImplementedError @abstractmethod - def p2p_connect(self, conn_req: DistServeConnectionRequest): + def p2p_connect(self, remote_engine_id: str, conn_req: DistServeKVTransferEndpointInfo): raise NotImplementedError @abstractmethod diff --git a/lmdeploy/pytorch/disagg/backend/dlslime.py b/lmdeploy/pytorch/disagg/backend/dlslime.py index d3421fd0a9..80257890b7 100644 --- a/lmdeploy/pytorch/disagg/backend/dlslime.py +++ b/lmdeploy/pytorch/disagg/backend/dlslime.py @@ -10,9 +10,10 @@ from lmdeploy.logger import get_logger from lmdeploy.pytorch.disagg.backend.backend import MIGRATION_BACKENDS from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl -from lmdeploy.pytorch.disagg.config import DistServeEngineConfig, MigrationBackend, MigrationProtocol +from lmdeploy.pytorch.disagg.config import DistServeEngineConfig, MigrationBackend +from lmdeploy.pytorch.disagg.conn.protocol import (DistServeInitRequest, DistServeKVTransferEndpointInfo, + MigrationProtocol) from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment -from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest logger = get_logger('lmdeploy') @@ -60,8 +61,8 @@ def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage register_mr_request.offset, register_mr_request.length) - def connect(self, connect_request: DistServeConnectionRequest): - self.endpoint[connect_request.protocol].connect(json.loads(connect_request.remote_endpoint_info)) + def connect(self, kvtransfer_endpoint_info: DistServeKVTransferEndpointInfo): + self.endpoint[kvtransfer_endpoint_info.protocol].connect(json.loads(kvtransfer_endpoint_info.endpoint_info)) async def p2p_migrate(self, assignment: MigrationAssignment, async_op=False): batch = [ @@ -91,6 +92,7 @@ def split(batch: List[DLSlimeAssignment]): @MIGRATION_BACKENDS.register_module(MigrationBackend.DLSlime.name) class DLSlimeBackend(MigrationBackendImpl): + """DLSlime Transfer Engine.""" def __init__(self): self.links: Dict[int, DLSlimeMigrationManagement] = {} @@ -104,8 +106,8 @@ def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage def endpoint_info(self, remote_engine_id: int, protocol: MigrationProtocol): return self.links[remote_engine_id].endpoint[protocol].endpoint_info - def p2p_connect(self, conn_req: DistServeConnectionRequest): - self.links[conn_req.remote_engine_id].connect(conn_req) + def p2p_connect(self, remote_engine_id: str, conn_req: DistServeKVTransferEndpointInfo): + self.links[remote_engine_id].connect(conn_req) async def p2p_migrate(self, assignment: MigrationAssignment, async_op: bool = False): await self.links[assignment.remote_engine_id].p2p_migrate(assignment, async_op=async_op) diff --git a/lmdeploy/pytorch/disagg/backend/infinistore.py b/lmdeploy/pytorch/disagg/backend/infinistore.py deleted file mode 100644 index f75850138f..0000000000 --- a/lmdeploy/pytorch/disagg/backend/infinistore.py +++ /dev/null @@ -1,31 +0,0 @@ -# 
Copyright (c) OpenMMLab. All rights reserved. -from lmdeploy.pytorch.disagg.backend.backend import MIGRATION_BACKENDS -from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl -from lmdeploy.pytorch.disagg.config import MigrationBackend, MigrationProtocol -from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment -from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest - - -@MIGRATION_BACKENDS.register_module(MigrationBackend.InfiniStore.name) -class InfiniStoreBackend(MigrationBackendImpl): - - def p2p_initialize(self, init_request: DistServeInitRequest): - raise NotImplementedError - - def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage): - raise NotImplementedError - - def endpoint_info(self, remote_engine_id: int, protocol: MigrationProtocol): - return NotImplementedError - - def p2p_connect(self, conn_req: DistServeConnectionRequest): - raise NotImplementedError - - def p2p_migrate(self, assignment: MigrationAssignment, async_op: bool = False): - raise NotImplementedError - - def store(self, assignment: MigrationAssignment, async_op: bool = False): - raise NotImplementedError - - def load(self, assignment: MigrationAssignment, async_op: bool = False): - raise NotImplementedError diff --git a/lmdeploy/pytorch/disagg/backend/mooncake.py b/lmdeploy/pytorch/disagg/backend/mooncake.py index c6f4bb4f19..e4ba7fbd5f 100644 --- a/lmdeploy/pytorch/disagg/backend/mooncake.py +++ b/lmdeploy/pytorch/disagg/backend/mooncake.py @@ -8,9 +8,10 @@ from lmdeploy.pytorch.disagg.backend.backend import MIGRATION_BACKENDS from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl -from lmdeploy.pytorch.disagg.config import MigrationBackend, MigrationProtocol, MooncakeEngineConfig +from lmdeploy.pytorch.disagg.config import MigrationBackend, MooncakeEngineConfig +from lmdeploy.pytorch.disagg.conn.protocol import (DistServeInitRequest, DistServeKVTransferEndpointInfo, + MigrationProtocol) from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment -from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest from lmdeploy.utils import get_logger logger = get_logger('lmdeploy') @@ -160,9 +161,9 @@ def endpoint_info(self) -> Dict: return endpoint_info - def connect(self, connect_request: DistServeConnectionRequest): + def connect(self, connect_request: DistServeKVTransferEndpointInfo): """Connect to the remote engine.""" - remote_endpoint_info = json.loads(connect_request.remote_endpoint_info) + remote_endpoint_info = json.loads(connect_request.endpoint_info) self.remote_url = remote_endpoint_info['session_id'] self.remote_kv_table = remote_endpoint_info['mr_info'] @@ -247,8 +248,8 @@ def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage def endpoint_info(self, remote_engine_id: int, protocol: MigrationProtocol): return self.links[remote_engine_id].endpoint_info - def p2p_connect(self, connect_request: DistServeConnectionRequest): - self.links[connect_request.remote_engine_id].connect(connect_request) + def p2p_connect(self, remote_engine_id: str, connect_request: DistServeKVTransferEndpointInfo): + self.links[remote_engine_id].connect(connect_request) async def p2p_migrate(self, assignment: MigrationAssignment, async_op: bool = False): await self.links[assignment.remote_engine_id].p2p_migrate(assignment, async_op=async_op) diff --git a/lmdeploy/pytorch/disagg/config.py 
b/lmdeploy/pytorch/disagg/config.py index c1e570fcbc..f4dd002231 100644 --- a/lmdeploy/pytorch/disagg/config.py +++ b/lmdeploy/pytorch/disagg/config.py @@ -42,24 +42,6 @@ class MigrationBackend(enum.Enum): DLSlime = enum.auto() Mooncake = enum.auto() - InfiniStore = enum.auto() - - -class MigrationProtocol(enum.Enum): - """Migration Transport Protocol. - - Attributes: - TCP: TCP for General Purpose Transport Protocol. - RDMA: IB or RoCEv1/v2. - NVLINK: High device-to-device link. - - Warning: By now, only `GPU Directed RDMA` is supported in DistServe. - We preserve several protocol and will be implemented in the future. - """ - - TCP = enum.auto() - RDMA = enum.auto() - NVLINK = enum.auto() class RDMALinkType(enum.Enum): @@ -133,13 +115,3 @@ class MooncakeEngineConfig(DistServeEngineConfig): TODO: Support more specific config for Mooncake. """ pass - - -class DistServeConfig(BaseModel): - """DistServe Config.""" - - serving_strategy: ServingStrategy - distserve_transport_protocol: MigrationProtocol - rdma_config: Optional[DistServeRDMAConfig] = None - nvlink_config: Optional[DistServeNVLinkConfig] = None - tcp_config: Optional[DistServeTCPConfig] = None diff --git a/lmdeploy/pytorch/disagg/conn.py b/lmdeploy/pytorch/disagg/conn.py deleted file mode 100644 index 263ed873e0..0000000000 --- a/lmdeploy/pytorch/disagg/conn.py +++ /dev/null @@ -1,216 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import asyncio -import enum -import json -import os -from typing import Dict, List, Tuple - -import aiohttp - -from lmdeploy.logger import get_logger -from lmdeploy.pytorch.disagg.config import DistServeEngineConfig -from lmdeploy.pytorch.disagg.messages import PDConnectionMessage -from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest - -logger = get_logger('lmdeploy') - -AIOHTTP_TIMEOUT = os.getenv('AIOHTTP_TIMEOUT', None) - - -class PDConnectionStatus(enum.Enum): - Disconnected = enum.auto() - Connected = enum.auto() - Connecting = enum.auto() - - -class PDConnectionState: - """PDConnectionState.""" - - def __init__(self, status: PDConnectionStatus, event: asyncio.Event): - self.status = status - self.event = event - - async def wait(self): - await self.event.wait() - - def set_status(self, status: PDConnectionStatus): - self.status = status - - -class PDConnectionPool: - """Constructing the link of Prefill and Decode engine for the migration of - KVCache. - - Note: we use Peer to Peer transportation in KVCache migration. - Note: Lazy link construction is supported, which perform connection - at the first LLM request. As a result, we don't need to construct - PD Communication group when start a engine server. - Warning: By now, only engines with same parallel configuration can be - correctly connected. - """ - - def __init__(self): - # Links of PD Connection. - self.pool: Dict[Tuple[str, str], PDConnectionState] = {} - - # conn_perform handler queue - self.waiting_conn: asyncio.Queue[Tuple[PDConnectionMessage, asyncio.Event]] = (asyncio.Queue()) - - # conn Registry Lock - self.conn_lock = asyncio.Lock() - - # Connection Retry when failure - self.max_retry_cnt = 8 - - # trigger signal when conn request arrive. 
- self.conn_req_event = asyncio.Event() - - # conn initialized signal - self.initialized = False - - async def perform_conn(self): - - def get_server_api(url: str, api: str): - return f'{url}/{api}' - - async def get_engine_config(server_endpoint): - async with self.conn_sem: - async with self.conn_sess.get( - get_server_api(server_endpoint, 'distserve/engine_info'), - timeout=self.aiotimeout, - ) as resp: - return DistServeEngineConfig.model_validate_json(await resp.json()) - - async def p2p_initialize(server_endpoint, init_request: DistServeInitRequest): - async with self.conn_sem: - async with self.conn_sess.post( - get_server_api(server_endpoint, 'distserve/p2p_initialize'), - json=init_request.model_dump(mode='json'), - timeout=self.aiotimeout, - ) as resp: - return await resp.json() - - async def p2p_connect(server_endpoint, conn_request: List[DistServeConnectionRequest]): - async with self.conn_sem: - async with self.conn_sess.post( - get_server_api(server_endpoint, 'distserve/p2p_connect'), - json=[req.model_dump(mode='json') for req in conn_request], - timeout=self.aiotimeout, - ) as resp: - return await resp.json() - - async def conn_worker(conn_req: PDConnectionMessage, conn_event: asyncio.Event): - try: - link = (conn_req.p_url, conn_req.d_url) - logger.debug(f'{link} connecting...') - # Step 1. Get Remote Engine Configuration - prefill_engine_config = await get_engine_config(conn_req.p_url) - decode_engine_config = await get_engine_config(conn_req.d_url) - - # Note: Only Same Parallel Configurations are supported by now - assert prefill_engine_config.tp_size == decode_engine_config.tp_size - - # Step 2. Construct Initialize Configuration - prefill_init_req = DistServeInitRequest( - protocol=conn_req.protocol, - local_engine_id=conn_req.p_url, - local_engine_config=prefill_engine_config, - remote_engine_id=conn_req.d_url, - remote_engine_config=decode_engine_config, - rdma_config=conn_req.rdma_config, - nvlink_config=conn_req.nvlink_config, - ) - decode_init_req = DistServeInitRequest(protocol=conn_req.protocol, - local_engine_id=conn_req.d_url, - local_engine_config=decode_engine_config, - remote_engine_id=conn_req.p_url, - remote_engine_config=prefill_engine_config, - rdma_config=conn_req.rdma_config, - nvlink_config=conn_req.nvlink_config) - - prefill_endpoint_info = await p2p_initialize(conn_req.p_url, prefill_init_req) - decode_endpoint_info = await p2p_initialize(conn_req.d_url, decode_init_req) - - # Step 3. 
Connection - prefill_endpoint_conn_reqs = [ - DistServeConnectionRequest( - protocol=conn_req.protocol, - remote_engine_id=conn_req.d_url, - remote_endpoint_info=json.dumps(info), - ) for info in decode_endpoint_info - ] - decode_endpoint_conn_reqs = [ - DistServeConnectionRequest( - protocol=conn_req.protocol, - remote_engine_id=conn_req.p_url, - remote_endpoint_info=json.dumps(info), - ) for info in prefill_endpoint_info - ] - await p2p_connect(conn_req.p_url, prefill_endpoint_conn_reqs) - await p2p_connect(conn_req.d_url, decode_endpoint_conn_reqs) - self.pool[link].set_status(PDConnectionStatus.Connected) - logger.debug(f'{(conn_req.p_url, conn_req.d_url)} connected') - except Exception as e: - self.pool[link].set_status(PDConnectionStatus.Disconnected) - logger.error(f'pd connection error: {e}') - conn_event.set() - - async def wait_for_conn(conn_req: PDConnectionMessage, conn_event: asyncio.Event): - await self.pool[(conn_req.p_url, conn_req.d_url)].event.wait() - conn_event.set() - - logger.debug('perform_conn start') - while True: - if self.waiting_conn.empty(): - await self.conn_req_event.wait() - - self.conn_req_event.clear() - - while not self.waiting_conn.empty(): - conn_req, conn_event = self.waiting_conn.get_nowait() - link = (conn_req.p_url, conn_req.d_url) - if link not in self.pool: - self.pool[link] = PDConnectionState( - PDConnectionStatus.Disconnected, - conn_event, - ) - if self.pool[link].status == PDConnectionStatus.Connecting: - asyncio.create_task(wait_for_conn(conn_req, conn_event)) - elif self.pool[link].status == PDConnectionStatus.Disconnected: - self.pool[link].set_status(PDConnectionStatus.Connecting) - asyncio.create_task(conn_worker(conn_req, conn_event)) - - async def connect(self, conn_req: PDConnectionMessage): - if not self.initialized: - loop = asyncio.get_event_loop() - loop.create_task(self.perform_conn()) - self.conn_sem = asyncio.Semaphore(1024) - self.conn_sess = aiohttp.ClientSession( - connector=aiohttp.TCPConnector(limit_per_host=256), - timeout=aiohttp.ClientTimeout(total=AIOHTTP_TIMEOUT), - ) - self.aiotimeout = aiohttp.ClientTimeout(total=AIOHTTP_TIMEOUT) - self.initialized = True - cnt = 0 - while cnt < self.max_retry_cnt: - if self.is_connected(conn_req.p_url, conn_req.d_url): - return - if cnt > 0: - logger.warning(f'Connection failure, retry cnt: {cnt}') - conn_event = asyncio.Event() - self.waiting_conn.put_nowait((conn_req, conn_event)) - self.conn_req_event.set() - await conn_event.wait() - cnt += 1 - async with self.conn_lock: - self.pool[conn_req.p_url, conn_req.d_url].set_status(PDConnectionStatus.Disconnected) - raise TimeoutError('PDConnection Failure') - - def is_connected(self, p_url: str, d_url: str): - link = self.pool.get((p_url, d_url), None) - if not link: - return False - return link.status == PDConnectionStatus.Connected - - def drop(self, left: str, right: str): - self.pool.pop((left, right), None) diff --git a/lmdeploy/pytorch/disagg/conn/__init__.py b/lmdeploy/pytorch/disagg/conn/__init__.py new file mode 100644 index 0000000000..ef101fec61 --- /dev/null +++ b/lmdeploy/pytorch/disagg/conn/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/lmdeploy/pytorch/disagg/conn/engine_conn.py b/lmdeploy/pytorch/disagg/conn/engine_conn.py new file mode 100644 index 0000000000..a14c684f3b --- /dev/null +++ b/lmdeploy/pytorch/disagg/conn/engine_conn.py @@ -0,0 +1,86 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import asyncio
+import os
+from typing import TYPE_CHECKING, Dict, List
+from urllib.parse import urlparse
+
+import zmq
+import zmq.asyncio
+
+from lmdeploy.logger import get_logger
+from lmdeploy.pytorch.disagg.conn.protocol import (DistServeCacheFreeRequest, DistServeConnectionRequest,
+                                                   DistServeConnectionResponse, DistServeConnectionStatus,
+                                                   DistServeDropConnectionRequest, DistServeEngineEndpointInfo,
+                                                   DistServeInitRequest, DistServeInitResponse,
+                                                   DistServeKVTransferEndpointInfo)
+from lmdeploy.pytorch.engine.executor.dist_utils import find_available_port
+
+if TYPE_CHECKING:
+    from lmdeploy.pytorch.engine.engine import Engine
+
+logger = get_logger('lmdeploy')
+
+
+class EngineP2PConnection:
+
+    def __init__(self, engine: 'Engine'):
+        self.engine: Engine = engine
+        self.p2p_conn_ctx: Dict[str, zmq.asyncio.Context] = {}
+        self.p2p_sender: Dict[str, zmq.asyncio.Socket] = {}
+        self.p2p_receiver: Dict[str, zmq.asyncio.Socket] = {}
+
+        self.use_unique_kvtransfer_engine = os.environ.get('LMDEPLOY_USE_UNIQUE_KVTRANSFER_ENGINE', False)
+
+    async def p2p_initialize(self, init_request: DistServeInitRequest):
+        ctx = zmq.asyncio.Context(2)
+        sender = ctx.socket(zmq.PUSH)
+        sender_port = find_available_port()
+        sender_hostname = urlparse(init_request.local_engine_id).hostname
+        zmq_address = f'tcp://{sender_hostname}:{sender_port}'
+        sender.bind(zmq_address)
+        receiver = ctx.socket(zmq.PULL)
+
+        self.p2p_conn_ctx[init_request.remote_engine_id] = ctx
+        self.p2p_sender[init_request.remote_engine_id] = sender
+        self.p2p_receiver[init_request.remote_engine_id] = receiver
+
+        kvtransfer_endpoint_info: List[DistServeKVTransferEndpointInfo] = self.engine.executor.p2p_initialize(
+            init_request)
+
+        return DistServeInitResponse(engine_endpoint_info=DistServeEngineEndpointInfo(zmq_address=zmq_address),
+                                     kvtransfer_endpoint_info=kvtransfer_endpoint_info,
+                                     status=DistServeConnectionStatus.SUCCESS)
+
+    def p2p_connect(self, conn_request: DistServeConnectionRequest):
+        self.p2p_receiver[conn_request.remote_engine_id].connect(conn_request.remote_engine_endpoint_info.zmq_address)
+        self.engine.executor.p2p_connect(remote_engine_id=conn_request.remote_engine_id,
+                                         conn_request=conn_request.remote_kvtransfer_endpoint_info)
+        event_loop = asyncio.get_event_loop()
+        event_loop.create_task(self.handle_zmq_recv(conn_request.remote_engine_id))
+        return DistServeConnectionResponse(status=DistServeConnectionStatus.SUCCESS)
+
+    def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest):
+        # TODO (JimyMa): drop RDMA Connection
+        self.zmq_disconnect(drop_conn_request.remote_engine_id)
+        return {'success': True}
+
+    async def zmq_send(self, remote_engine_id: str, remote_session_id: int):
+        await self.p2p_sender[remote_engine_id].send_pyobj(
+            DistServeCacheFreeRequest(remote_engine_id=remote_engine_id, remote_session_id=remote_session_id))
+
+    async def handle_zmq_recv(self, remote_engine_id: str):
+        while True:
+            req: DistServeCacheFreeRequest = await self.p2p_receiver[remote_engine_id].recv_pyobj()
+            if isinstance(req, DistServeCacheFreeRequest):
+                session_id = req.remote_session_id
+                if session_id in self.engine.scheduler.sessions:
+                    self.engine.scheduler.end_session(session_id=session_id)
+                else:
+                    logger.error(f'invalid free, {remote_engine_id}, {session_id}')
+            else:
+                raise ValueError(f'Unsupported zmq request {type(req)}')
+
+    def zmq_disconnect(self, remote_engine_id: str):
+        # Synchronous on purpose: nothing here is awaited, and the caller
+        # (`p2p_drop_connect`) invokes it without `await`.
+        self.p2p_receiver[remote_engine_id].close()
+        self.p2p_sender[remote_engine_id].close()
+        self.p2p_conn_ctx[remote_engine_id].term()
diff --git a/lmdeploy/pytorch/disagg/conn/protocol.py b/lmdeploy/pytorch/disagg/conn/protocol.py
new file mode 100644
index 0000000000..aa47789497
--- /dev/null
+++ b/lmdeploy/pytorch/disagg/conn/protocol.py
@@ -0,0 +1,98 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import enum
+from typing import List, Optional
+
+from pydantic import BaseModel
+
+from lmdeploy.pytorch.disagg.config import (DistServeEngineConfig, DistServeNVLinkConfig, DistServeRDMAConfig,
+                                            DistServeTCPConfig)
+
+
+class MigrationProtocol(enum.Enum):
+    """Migration Transport Protocol.
+
+    Attributes:
+        TCP: TCP for general-purpose transport.
+        RDMA: IB or RoCEv1/v2.
+        NVLINK: High-bandwidth device-to-device link.
+
+    Warning: By now, only `GPU Direct RDMA` is supported in DistServe.
+        We reserve the other protocols; they will be implemented in the future.
+    """
+
+    TCP = enum.auto()
+    RDMA = enum.auto()
+    NVLINK = enum.auto()
+
+
+class DistServeConnectionStatus(enum.Enum):
+    # TODO(JimyMa): Add more connection failure handlers
+    SUCCESS = enum.auto()
+    FAIL = enum.auto()
+
+
+class DistServeInitRequest(BaseModel):
+    local_engine_id: str
+    local_engine_config: DistServeEngineConfig
+
+    remote_engine_id: str
+    remote_engine_config: DistServeEngineConfig
+
+    protocol: MigrationProtocol
+
+    rank: Optional[int] = None
+
+    tcp_config: Optional[DistServeTCPConfig] = None
+    rdma_config: Optional[DistServeRDMAConfig] = None
+    nvlink_config: Optional[DistServeNVLinkConfig] = None
+
+
+class DistServeEngineEndpointInfo(BaseModel):
+    zmq_address: str
+
+
+class DistServeKVTransferEndpointInfo(BaseModel):
+    protocol: MigrationProtocol
+    endpoint_info: str
+
+
+class DistServeInitResponse(BaseModel):
+    status: DistServeConnectionStatus
+    # the control-plane initialization feedback
+    engine_endpoint_info: DistServeEngineEndpointInfo
+    # the KVCache-transfer initialization feedback.
+    # To ensure generality (endpoint_info can be initialization information
+    # for different media such as RDMA, NVLink, etc.), we use a string (str) to
+    # store this information.
+    kvtransfer_endpoint_info: List[DistServeKVTransferEndpointInfo]
+
+
+class DistServeConnectionRequest(BaseModel):
+    protocol: MigrationProtocol
+    remote_engine_id: str
+    remote_engine_endpoint_info: DistServeEngineEndpointInfo
+    remote_kvtransfer_endpoint_info: List[DistServeKVTransferEndpointInfo]
+
+
+class DistServeConnectionResponse(BaseModel):
+    status: DistServeConnectionStatus
+
+
+class MigrationRequest(BaseModel):
+    protocol: MigrationProtocol
+
+    remote_engine_id: str
+    remote_session_id: int
+    remote_token_id: int
+    remote_block_ids: List[int]
+
+    is_dummy_prefill: bool = False
+
+
+class DistServeCacheFreeRequest(BaseModel):
+    remote_engine_id: str
+    remote_session_id: int
+
+
+class DistServeDropConnectionRequest(BaseModel):
+    engine_id: str
+    remote_engine_id: str
diff --git a/lmdeploy/pytorch/disagg/conn/proxy_conn.py b/lmdeploy/pytorch/disagg/conn/proxy_conn.py
new file mode 100644
index 0000000000..a07d281248
--- /dev/null
+++ b/lmdeploy/pytorch/disagg/conn/proxy_conn.py
@@ -0,0 +1,301 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import asyncio
+import enum
+import os
+from collections import defaultdict
+from typing import Dict, Set, Tuple
+
+import aiohttp
+import requests
+
+from lmdeploy.logger import get_logger
+from lmdeploy.pytorch.disagg.config import DistServeEngineConfig, EngineRole
+from lmdeploy.pytorch.disagg.conn.protocol import (DistServeCacheFreeRequest, DistServeConnectionRequest,
+                                                   DistServeConnectionResponse, DistServeDropConnectionRequest,
+                                                   DistServeInitRequest, DistServeInitResponse)
+from lmdeploy.pytorch.disagg.messages import PDConnectionMessage
+
+logger = get_logger('lmdeploy')
+
+AIOHTTP_TIMEOUT = os.getenv('AIOHTTP_TIMEOUT', None)
+
+
+class PDConnectionStatus(enum.Enum):
+    Disconnected = enum.auto()
+    Connected = enum.auto()
+    Connecting = enum.auto()
+
+
+class PDConnectionState:
+    """PDConnectionState."""
+
+    def __init__(self, status: PDConnectionStatus, event: asyncio.Event):
+        self.status = status
+        self.event = event
+
+    async def wait(self):
+        await self.event.wait()
+
+    def set_status(self, status: PDConnectionStatus):
+        self.status = status
+
+
+def get_server_api(url: str, api: str):
+    return f'{url}/{api}'
+
+
+class PDConnectionPool:
+    """Construct the links between Prefill and Decode engines for the
+    migration of KVCache.
+
+    Note: we use peer-to-peer transportation for KVCache migration.
+    Note: lazy link construction is supported; the connection is performed
+        at the first LLM request, so we don't need to construct the
+        PD communication group when starting an engine server.
+    Note: we perform simple fault tolerance by checkpointing the session_id of a
+        request that is being migrated, and trigger `gc` when the decode
+        instance crashes.
+    TODO (JimyMa): By now, only engines with the same parallel configuration can be
+        correctly connected.
+    """
+
+    # Maximum concurrent connections
+    CONN_SEMAPHORE_SIZE = 2048
+
+    def __init__(self):
+        # all prefill and decode instances
+        # TODO (JimyMa): maybe also encoding instances
+        self.prefill_endpoints: Set[str] = set()
+        self.decode_endpoints: Set[str] = set()
+
+        # Links of PD Connection.
+        self.pool: Dict[Tuple[str, str], PDConnectionState] = {}
+
+        # put migrating sessions in `self.migration_session_shelf` for fault tolerance:
+        # if a session is finished, pop it from `self.migration_session_shelf`;
+        # if a decode instance is disconnected, gc all blocks of these sessions in the prefill instance.
+        self.migration_session_shelf: Dict[Tuple[str, str], Set[int]] = defaultdict(set)
+
+        # conn_perform handler queue
+        self.waiting_conn: asyncio.Queue[Tuple[PDConnectionMessage, asyncio.Event]] = asyncio.Queue()
+
+        # conn registry lock
+        self.conn_lock = asyncio.Lock()
+
+        # connection retry count on failure
+        self.max_retry_cnt = 8
+
+        # trigger signal when a conn request arrives.
+ self.conn_req_event = asyncio.Event() + + # conn initialized signal + self.initialized = False + + def reg_instance(self, role: EngineRole, endpoint: str): + if role == EngineRole.Prefill: + self.prefill_endpoints.add(endpoint) + elif role == EngineRole.Decode: + self.decode_endpoints.add(endpoint) + else: + raise ValueError(f'Unsupported role: {role}') + + def dereg_instance(self, endpoint: str): + if endpoint in self.prefill_endpoints: + self.prefill_endpoints.remove(endpoint) + elif endpoint in self.decode_endpoints: + dropped_key = [] + for conn_key in self.pool.keys(): + if conn_key[1] == endpoint: + dropped_key.append(conn_key) + for k in dropped_key: + self.drop(k) + # TODO(JimyMa): handle side-effect by kvcache migration + self.decode_endpoints.remove(endpoint) + + def shelf_prefill_session(self, conn_key: Tuple[str, str], session_id: int): + self.migration_session_shelf[conn_key].add(session_id) + + def unshelf_prefill_session(self, conn_key: Tuple[str, str], session_id: int): + self.migration_session_shelf[conn_key].remove(session_id) + + async def connect(self, conn_req: PDConnectionMessage): + + async def get_engine_config(server_endpoint): + async with self.conn_sem: + async with self.conn_sess.get( + get_server_api(server_endpoint, 'distserve/engine_info'), + timeout=self.aiotimeout, + ) as resp: + result = await resp.json() + return DistServeEngineConfig.model_validate_json(result) + + async def p2p_initialize(server_endpoint, init_request: DistServeInitRequest) -> DistServeInitResponse: + async with self.conn_sem: + async with self.conn_sess.post( + get_server_api(server_endpoint, 'distserve/p2p_initialize'), + json=init_request.model_dump(mode='json'), + timeout=self.aiotimeout, + ) as resp: + result = await resp.json() + return DistServeInitResponse.model_validate(result) + + async def p2p_connect(server_endpoint, conn_request: DistServeConnectionRequest) -> DistServeConnectionResponse: + async with self.conn_sem: + async with self.conn_sess.post( + get_server_api(server_endpoint, 'distserve/p2p_connect'), + json=conn_request.model_dump(mode='json'), + timeout=self.aiotimeout, + ) as resp: + result = await resp.json() + return DistServeConnectionResponse.model_validate(result) + + async def conn_worker(conn_req: PDConnectionMessage, conn_event: asyncio.Event): + try: + link = (conn_req.p_url, conn_req.d_url) + logger.debug(f'{link} connecting...') + # Step 1. Get Remote Engine Configuration + prefill_engine_config = await get_engine_config(conn_req.p_url) + decode_engine_config = await get_engine_config(conn_req.d_url) + + # Note: Only Same Parallel Configurations are supported by now + assert prefill_engine_config.tp_size == decode_engine_config.tp_size + + # Step 2. 
Construct Initialize Configuration + prefill_init_req = DistServeInitRequest( + protocol=conn_req.protocol, + local_engine_id=conn_req.p_url, + local_engine_config=prefill_engine_config, + remote_engine_id=conn_req.d_url, + remote_engine_config=decode_engine_config, + rdma_config=conn_req.rdma_config, + nvlink_config=conn_req.nvlink_config, + ) + decode_init_req = DistServeInitRequest( + protocol=conn_req.protocol, + local_engine_id=conn_req.d_url, + local_engine_config=decode_engine_config, + remote_engine_id=conn_req.p_url, + remote_engine_config=prefill_engine_config, + rdma_config=conn_req.rdma_config, + nvlink_config=conn_req.nvlink_config, + ) + + prefill_init_resp = await p2p_initialize(conn_req.p_url, prefill_init_req) + decode_init_resp = await p2p_initialize(conn_req.d_url, decode_init_req) + + # Step 3. Connection + prefill_endpoint_conn_reqs = DistServeConnectionRequest( + protocol=conn_req.protocol, + remote_engine_id=conn_req.d_url, + remote_engine_endpoint_info=decode_init_resp.engine_endpoint_info, + remote_kvtransfer_endpoint_info=decode_init_resp.kvtransfer_endpoint_info) + decode_endpoint_conn_reqs = DistServeConnectionRequest( + protocol=conn_req.protocol, + remote_engine_id=conn_req.p_url, + remote_engine_endpoint_info=prefill_init_resp.engine_endpoint_info, + remote_kvtransfer_endpoint_info=prefill_init_resp.kvtransfer_endpoint_info) + await p2p_connect(conn_req.p_url, prefill_endpoint_conn_reqs) + await p2p_connect(conn_req.d_url, decode_endpoint_conn_reqs) + self.pool[link].set_status(PDConnectionStatus.Connected) + logger.debug(f'{(conn_req.p_url, conn_req.d_url)} connected') + except Exception as e: + self.pool[link].set_status(PDConnectionStatus.Disconnected) + logger.error(f'pd connection error: {e}') + conn_event.set() + + async def wait_for_conn(conn_req: PDConnectionMessage, conn_event: asyncio.Event): + await self.pool[(conn_req.p_url, conn_req.d_url)].event.wait() + conn_event.set() + + async def _perform_conn(): + logger.debug('perform_conn start') + while True: + if self.waiting_conn.empty(): + await self.conn_req_event.wait() + + self.conn_req_event.clear() + + while not self.waiting_conn.empty(): + conn_req, conn_event = self.waiting_conn.get_nowait() + link = (conn_req.p_url, conn_req.d_url) + if link not in self.pool: + self.pool[link] = PDConnectionState( + PDConnectionStatus.Disconnected, + conn_event, + ) + if self.pool[link].status == PDConnectionStatus.Connecting: + asyncio.create_task(wait_for_conn(conn_req, conn_event)) + elif self.pool[link].status == PDConnectionStatus.Disconnected: + self.pool[link].set_status(PDConnectionStatus.Connecting) + asyncio.create_task(conn_worker(conn_req, conn_event)) + + if not self.initialized: + loop = asyncio.get_event_loop() + loop.create_task(_perform_conn()) + self.conn_sem = asyncio.Semaphore(self.CONN_SEMAPHORE_SIZE) + self.conn_sess = aiohttp.ClientSession( + connector=aiohttp.TCPConnector(limit_per_host=256), + timeout=aiohttp.ClientTimeout(total=AIOHTTP_TIMEOUT), + ) + self.aiotimeout = aiohttp.ClientTimeout(total=AIOHTTP_TIMEOUT) + self.initialized = True + + self.reg_instance(EngineRole.Prefill, conn_req.p_url) + self.reg_instance(EngineRole.Decode, conn_req.d_url) + + cnt = 0 + while cnt < self.max_retry_cnt: + if self.is_connected(conn_req.p_url, conn_req.d_url): + return + if cnt > 0: + logger.warning(f'Connection failure, retry cnt: {cnt}') + conn_event = asyncio.Event() + self.waiting_conn.put_nowait((conn_req, conn_event)) + self.conn_req_event.set() + await conn_event.wait() + cnt += 1 + 
async with self.conn_lock: + self.pool[conn_req.p_url, conn_req.d_url].set_status(PDConnectionStatus.Disconnected) + raise TimeoutError('PDConnection Failure') + + def is_connected(self, p_url: str, d_url: str): + link = self.pool.get((p_url, d_url), None) + if not link: + return False + return link.status == PDConnectionStatus.Connected + + def drop(self, pd_key: Tuple[str, str]): + left = pd_key[0] + right = pd_key[1] + + def cache_free(server_endpoint, cache_free_request: DistServeCacheFreeRequest) -> Dict: + try: + requests.post(get_server_api(server_endpoint, 'distserve/free_cache'), + json=cache_free_request.model_dump(mode='json')) + except Exception as e: + logger.warning(f'error cache block free {server_endpoint, cache_free_request}. ErrorMsg: {str(e)}') + + def drop_connect(server_endpoint: str, p2p_disconnect_request: DistServeDropConnectionRequest): + try: + requests.post(get_server_api(server_endpoint, 'distserve/p2p_drop_connect'), + json=p2p_disconnect_request.model_dump(mode='json')) + except Exception as e: + logger.warning(f'error drop connect {server_endpoint, p2p_disconnect_request}. ErrorMsg: {str(e)}') + + # trigger gc + logger.warning('cache block gc triggered.') + try: + for session_id in self.migration_session_shelf[(left, right)]: + cache_free(left, DistServeCacheFreeRequest(remote_engine_id=left, remote_session_id=session_id)) + except Exception as e: + logger.warning(f'gc error, ErrorMsg: {str(e)}') + + # trigger p2p disconnect + logger.warning('drop connection triggered.') + try: + drop_connect(left, DistServeDropConnectionRequest(engine_id=left, remote_engine_id=right)) + drop_connect(right, DistServeDropConnectionRequest(engine_id=right, remote_engine_id=left)) + except Exception as e: + logger.warning(f'p2p disconnect error, ErrorMsg: {str(e)}') + + self.pool.pop((left, right), None) diff --git a/lmdeploy/pytorch/disagg/messages.py b/lmdeploy/pytorch/disagg/messages.py index b49769102d..9dac0b0391 100644 --- a/lmdeploy/pytorch/disagg/messages.py +++ b/lmdeploy/pytorch/disagg/messages.py @@ -3,8 +3,8 @@ from pydantic import BaseModel -from lmdeploy.pytorch.disagg.config import (DistServeNVLinkConfig, DistServeRDMAConfig, DistServeTCPConfig, - MigrationProtocol) +from lmdeploy.pytorch.disagg.config import DistServeNVLinkConfig, DistServeRDMAConfig, DistServeTCPConfig +from lmdeploy.pytorch.disagg.conn.protocol import MigrationProtocol class MigrationExecutionBatch(BaseModel): diff --git a/lmdeploy/pytorch/disagg/request.py b/lmdeploy/pytorch/disagg/request.py deleted file mode 100644 index 990fe814bf..0000000000 --- a/lmdeploy/pytorch/disagg/request.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-from typing import List, Optional - -from pydantic import BaseModel - -from lmdeploy.pytorch.disagg.config import (DistServeEngineConfig, DistServeNVLinkConfig, DistServeRDMAConfig, - DistServeTCPConfig, MigrationProtocol) - - -class DistServeConnectionRequest(BaseModel): - protocol: MigrationProtocol - remote_engine_id: str - remote_endpoint_info: str - - -class DistServeInitRequest(BaseModel): - local_engine_id: str - local_engine_config: DistServeEngineConfig - - remote_engine_id: str - remote_engine_config: DistServeEngineConfig - - protocol: MigrationProtocol - - rank: Optional[int] = None - - tcp_config: Optional[DistServeTCPConfig] = None - rdma_config: Optional[DistServeRDMAConfig] = None - nvlink_config: Optional[DistServeNVLinkConfig] = None - - -class MigrationRequest(BaseModel): - protocol: MigrationProtocol - - remote_engine_id: str - remote_session_id: int - remote_token_id: int - remote_block_ids: List[int] diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py index b769d1892a..fe61572a20 100644 --- a/lmdeploy/pytorch/engine/cache_engine.py +++ b/lmdeploy/pytorch/engine/cache_engine.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. # modify from: https://github.com/vllm-project/vllm +import json from typing import Dict, List, Literal, Optional, Tuple import torch @@ -7,9 +8,9 @@ from lmdeploy.pytorch.backends import get_backend from lmdeploy.pytorch.disagg.backend.backend import MIGRATION_BACKENDS from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl +from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo from lmdeploy.pytorch.disagg.messages import (AssignmentInstruct, DistServeRegisterMRMessage, MigrationAssignment, MigrationExecutionBatch) -from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest from lmdeploy.utils import get_logger from ..config import CacheConfig, ModelConfig @@ -315,7 +316,7 @@ def get_cache_block_size(cls, """ Metheds for PD Disaggregation Begin. 
""" - def p2p_initialize(self, migration_init_request: DistServeInitRequest): + def p2p_initialize(self, migration_init_request: DistServeInitRequest) -> DistServeKVTransferEndpointInfo: if not self.migration_backend_impl: self.migration_backend_impl = MIGRATION_BACKENDS.module_dict[self.cache_config.migration_backend.name]() migration_init_request.rank = self.rank @@ -330,11 +331,14 @@ def p2p_initialize(self, migration_init_request: DistServeInitRequest): offset=t.storage_offset(), length=t.numel() * t.itemsize) self.migration_backend_impl.register_memory_region(register_mr_request) - return self.migration_backend_impl.endpoint_info(migration_init_request.remote_engine_id, - migration_init_request.protocol) - - def p2p_connect(self, migration_conn_request: DistServeConnectionRequest): - self.migration_backend_impl.p2p_connect(migration_conn_request[self.tp_rank]) + return DistServeKVTransferEndpointInfo(protocol=migration_init_request.protocol, + endpoint_info=json.dumps( + self.migration_backend_impl.endpoint_info( + migration_init_request.remote_engine_id, + migration_init_request.protocol))) + + def p2p_connect(self, remote_engine_id: str, migration_conn_request: List[DistServeKVTransferEndpointInfo]): + self.migration_backend_impl.p2p_connect(remote_engine_id, migration_conn_request[self.tp_rank]) async def migrate(self, migration_execution_inputs: MigrationExecutionBatch): diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 9fd87647f7..ba66b615ba 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -11,6 +11,9 @@ from lmdeploy.messages import MetricsInfo, PytorchEngineConfig, ResponseType from lmdeploy.pytorch.disagg.config import EngineRole +from lmdeploy.pytorch.disagg.conn.engine_conn import EngineP2PConnection +from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest, + DistServeInitRequest) from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch from lmdeploy.utils import get_logger, get_max_batch_size, get_model, logging_timer @@ -383,8 +386,13 @@ def __init__(self, self._start_loop() self._loop_main = None - # for migration loop management + # for PD Disaggregation + # For migrating prefill request to decode engine self.migration_event: asyncio.Event = None + # For backpressure prefill request when cache is full + self.perfill_watermark_event: asyncio.Event = None + + self.engine_conn = EngineP2PConnection(self) @classmethod def from_pretrained(cls, @@ -754,7 +762,7 @@ def update_running(self, running: SeqList, next_token_ids: torch.Tensor, stopped msg.update_token_ids(update_token, model_meta=model_meta) msg.num_new_tokens += 1 if stop: - msg.status = MessageStatus.STOPPED + msg.status = MessageStatus.TO_BE_MIGRATED if msg.preserve_cache else MessageStatus.STOPPED def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray, stopped: torch.Tensor, model_metas: List[Dict[str, Any]]): @@ -789,7 +797,7 @@ def _make_infer_outputs(self, new_token_timestamp: float, next_token_ids: torch. 
                 if not is_run[idx]:
                     continue
                 token_ids = msg.all_ids[-msg.num_new_tokens:]
-            finish = msg.status == MessageStatus.STOPPED
+            finish = msg.status == MessageStatus.STOPPED or msg.status == MessageStatus.TO_BE_MIGRATED
             if not finish and len(token_ids) == 0:
                 continue
             session_id = msg.session_id
@@ -880,7 +888,8 @@ def __need_logits(seqs: SeqList):
         swap_in_map = scheduler_output.swap_in_map
         swap_out_map = scheduler_output.swap_out_map
 
-        assert len(running) > 0
+        if len(running) == 0:
+            return None
 
         # create inputs
         inputs = self.create_model_inputs(running, prefill)
@@ -959,6 +968,15 @@ def __send_resps(step_outputs: List[InferOutput]):
             await self._await_forward_event(forward_event)
             __send_resps(resps)
 
+    async def p2p_initialize(self, init_request: DistServeInitRequest):
+        return await self.engine_conn.p2p_initialize(init_request)
+
+    def p2p_connect(self, conn_request: DistServeConnectionRequest):
+        return self.engine_conn.p2p_connect(conn_request)
+
+    async def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest):
+        return self.engine_conn.p2p_drop_connect(drop_conn_request)
+
     @torch.inference_mode()
     async def _async_loop_migration(self, resp_que: asyncio.Queue, has_runable_event: asyncio.Event):
         """Async loop migration."""
@@ -974,16 +992,22 @@
             prefill_block_ids = migration_request.remote_block_ids
             decode_block_ids = list(self.scheduler.block_manager.get_block_table(msg=msg))
 
-            assert len(prefill_block_ids) == len(decode_block_ids)
-            migration_execution_requests.append((
-                migration_request.remote_engine_id,
-                list(zip(prefill_block_ids, decode_block_ids)),
-            ))
-            migration_inputs = MigrationExecutionBatch(protocol=migration_request.protocol,
-                                                       requests=migration_execution_requests)
-            logger.info(f'migrating session: {msg.session_id} begin')
-            await self.executor.migrate(migration_inputs)
-            logger.info(f'migrating session: {msg.session_id} done')
+            if not migration_request.is_dummy_prefill:
+                assert len(prefill_block_ids) == len(decode_block_ids), (
+                    f'#prefill block ids ({len(prefill_block_ids)}) must be equal to '
+                    f'#decode block ids ({len(decode_block_ids)}), '
+                    f'num_token_ids: {msg.num_token_ids}')
+                migration_execution_requests.append((
+                    migration_request.remote_engine_id,
+                    list(zip(prefill_block_ids, decode_block_ids)),
+                ))
+                migration_inputs = MigrationExecutionBatch(protocol=migration_request.protocol,
+                                                           requests=migration_execution_requests)
+                logger.info(f'migrating session: {msg.session_id} begin')
+                await self.executor.migrate(migration_inputs)
+                logger.info(f'migrating session: {msg.session_id} done')
+                await self.engine_conn.zmq_send(remote_engine_id=migration_request.remote_engine_id,
+                                                remote_session_id=migration_request.remote_session_id)
 
             # generate output
             outputs: Dict[int, InferOutput] = dict()
@@ -1005,7 +1029,7 @@
                 has_runable_event.set()
             else:
                 # release coroutine for decoding
-                await asyncio.sleep(0)
+                await asyncio.sleep(0.5)
 
     @torch.inference_mode()
     async def _async_loop_main(
@@ -1028,6 +1052,14 @@
                 await has_runable_event.wait()
                 scheduler.collect_migration_done()
                 forward_inputs, next_running = await inputs_maker.send_next_inputs()
+                if next_running is None:
+                    # TODO (JimyMa): add a watermark check event instead of an async sleep.
+                    # self.prefill_watermark_event.wait()
+                    logger.warning(f'no next prefill running request, maybe the cache is full, '
+                                   f'free gpu cache blocks: {scheduler.block_manager.get_num_free_gpu_blocks()}, '
+                                   f'total gpu cache blocks: {scheduler.block_manager.num_gpu_blocks}')
+                    await asyncio.sleep(0.1)
+                    continue
                 num_loops = forward_inputs['loop_count']
                 running = next_running
                 next_running = None
@@ -1175,14 +1207,6 @@ def end_session(self, session_id: int):
             return True
         return False
 
-    def p2p_initialize(self, conn_request):
-        """Init rdma link."""
-        return self.executor.p2p_initialize(conn_request)
-
-    def p2p_connect(self, conn_request):
-        """rdma_connect."""
-        return self.executor.p2p_connect(conn_request)
-
     def get_model_config(self):
         return self.model_config
 
diff --git a/lmdeploy/pytorch/engine/executor/base.py b/lmdeploy/pytorch/engine/executor/base.py
index 0af69c93cc..ba345c81e7 100644
--- a/lmdeploy/pytorch/engine/executor/base.py
+++ b/lmdeploy/pytorch/engine/executor/base.py
@@ -4,8 +4,8 @@
 from typing import Any, Dict, List
 
 from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig
+from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo
 from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch
-from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest
 from lmdeploy.pytorch.engine.cache_engine import CacheEngine
 from lmdeploy.utils import get_logger
@@ -104,7 +104,7 @@ def p2p_initialize(self, remote_engine_config: DistServeInitRequest):
         """Init rdma link."""
         raise NotImplementedError('Not implemented')
 
-    def p2p_connect(self, conn_request: List[DistServeConnectionRequest]):
+    def p2p_connect(self, remote_engine_id: str, conn_request: List[DistServeKVTransferEndpointInfo]):
         """rdma_connect."""
         raise NotImplementedError('Not Implemented')
 
diff --git a/lmdeploy/pytorch/engine/executor/base_worker.py b/lmdeploy/pytorch/engine/executor/base_worker.py
index 93870c6bda..c92ffea582 100644
--- a/lmdeploy/pytorch/engine/executor/base_worker.py
+++ b/lmdeploy/pytorch/engine/executor/base_worker.py
@@ -5,8 +5,8 @@
 from lmdeploy.pytorch.backends.selector import get_backend
 from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig
 from lmdeploy.pytorch.devices import DeviceContext
+from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo
 from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch
-from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest
 from lmdeploy.pytorch.distributed import DistContext
 from lmdeploy.pytorch.engine.model_agent import build_model_agent
 from lmdeploy.utils import get_logger
@@ -174,8 +174,8 @@
     def p2p_initialize(self, init_request: DistServeInitRequest):
         return self.model_agent.cache_engine.p2p_initialize(init_request)
 
-    def p2p_connect(self, conn_request: List[DistServeConnectionRequest]):
-        return self.model_agent.cache_engine.p2p_connect(conn_request)
+    def p2p_connect(self, remote_engine_id: str, conn_request: List[DistServeKVTransferEndpointInfo]):
+        return self.model_agent.cache_engine.p2p_connect(remote_engine_id, conn_request)
 
     async def migrate(self, inputs: MigrationExecutionBatch):
         return await self.model_agent.cache_engine.migrate(inputs)
diff --git a/lmdeploy/pytorch/engine/executor/ray_executor.py b/lmdeploy/pytorch/engine/executor/ray_executor.py
index 0e6bd1114f..9d2da62bfa 100644
--- a/lmdeploy/pytorch/engine/executor/ray_executor.py
+++ b/lmdeploy/pytorch/engine/executor/ray_executor.py
@@ -16,8 +16,8 @@
 from lmdeploy.pytorch.backends.selector import init_backend
 from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig
 from lmdeploy.pytorch.devices import DeviceContext, get_device_manager
+from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo
 from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch
-from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest
 from lmdeploy.utils import get_logger, try_import_deeplink
 
 from .base import ExecutorBase
@@ -649,9 +649,12 @@ def _init_ascend_distributed_environment(self, driver_ip):
     def p2p_initialize(self, init_request: DistServeInitRequest):
         return self.collective_rpc('p2p_initialize', (init_request, ))
 
-    def p2p_connect(self, conn_request: List[DistServeConnectionRequest]):
+    def p2p_connect(self, remote_engine_id: str, conn_request: List[DistServeKVTransferEndpointInfo]):
         """Rdma connect."""
-        return self.collective_rpc('p2p_connect', (conn_request, ))
+        return self.collective_rpc('p2p_connect', (
+            remote_engine_id,
+            conn_request,
+        ))
 
     async def migrate(self, batch: MigrationExecutionBatch):
         jobs = (worker.migrate.remote(batch) for worker in self.workers)
diff --git a/lmdeploy/pytorch/engine/executor/uni_executor.py b/lmdeploy/pytorch/engine/executor/uni_executor.py
index c4884d52df..bfe5816c5f 100644
--- a/lmdeploy/pytorch/engine/executor/uni_executor.py
+++ b/lmdeploy/pytorch/engine/executor/uni_executor.py
@@ -4,8 +4,8 @@
 from lmdeploy.pytorch.config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig
 from lmdeploy.pytorch.devices import DeviceContext
+from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo
 from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch
-from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest
 from lmdeploy.pytorch.engine.model_agent import build_model_agent
 from lmdeploy.utils import get_logger
@@ -112,7 +112,7 @@
         """
         return [self.model_agent.cache_engine.p2p_initialize(init_request)]
 
-    def p2p_connect(self, conn_request: List[DistServeConnectionRequest]):
+    def p2p_connect(self, remote_engine_id: str, conn_request: List[DistServeKVTransferEndpointInfo]):
         """rdma_connect."""
-        self.model_agent.cache_engine.p2p_connect(conn_request)
+        self.model_agent.cache_engine.p2p_connect(remote_engine_id, conn_request)
diff --git a/lmdeploy/pytorch/engine/mp_engine/mp_engine.py b/lmdeploy/pytorch/engine/mp_engine/mp_engine.py
index a656680d05..f6a7b0ad25 100644
--- a/lmdeploy/pytorch/engine/mp_engine/mp_engine.py
+++ b/lmdeploy/pytorch/engine/mp_engine/mp_engine.py
@@ -3,14 +3,20 @@
 import pickle
 import signal
 from contextlib import asynccontextmanager
+from typing import TYPE_CHECKING
 
 import torch.multiprocessing as mp
 
 from lmdeploy.messages import PytorchEngineConfig
+from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest,
+                                                   DistServeInitRequest)
 from lmdeploy.utils import get_logger
 
 logger = get_logger('lmdeploy')
 
+if TYPE_CHECKING:
+    from lmdeploy.pytorch.engine.engine import Engine
+
 
 def cancel_async_tasks(loop: asyncio.AbstractEventLoop):
     """Cancel async tasks."""
@@ -168,7 +174,7 @@ def _signal_handler(signum, frame):
             cancel_async_tasks(loop)
 
     @staticmethod
-    async def _mp_proc_async(server, engine):
+    async def
_mp_proc_async(server, engine: 'Engine'):
         """Mp process function."""
         engine.start_loop()
         instance_pool = EngineInstancePool(engine)
@@ -178,6 +184,7 @@ async def _mp_proc_async(server, engine):
         server.register_method('get_model_config', engine.get_model_config)
         server.register_method('p2p_initialize', engine.p2p_initialize)
         server.register_method('p2p_connect', engine.p2p_connect)
+        server.register_method('p2p_drop_connect', engine.p2p_drop_connect)
         server.register_method('instance_async_end', instance_pool.async_end)
         server.register_method('instance_async_cancel', instance_pool.async_cancel)
         server.register_method('instance_async_stream_infer', instance_pool.async_stream_infer)
@@ -214,14 +221,22 @@ def end_session(self, session_id: int):
         """End session."""
         return self._collective_rpc('end_session', session_id)
 
-    def p2p_initialize(self, conn_request):
+    def p2p_initialize(self, conn_request: DistServeInitRequest):
         """Init rdma link."""
         return self._collective_rpc('p2p_initialize', conn_request)
 
-    def p2p_connect(self, conn_request):
+    def p2p_connect(self, conn_request: DistServeConnectionRequest):
         """rdma_connect."""
         return self._collective_rpc('p2p_connect', conn_request)
 
+    def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest):
+        """Drop connection.
+
+        1. drop engine connection (zmq connection)
+        2. TODO(JimyMa) drop RDMA Connection.
+        """
+        return self._collective_rpc('p2p_drop_connect', drop_conn_request)
+
     def create_instance(self, cuda_stream_id=0):
         """Create instance."""
         return MPEngineInstance(self)
diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py
index f447b70dc9..b73cb5df38 100644
--- a/lmdeploy/pytorch/messages.py
+++ b/lmdeploy/pytorch/messages.py
@@ -8,7 +8,7 @@
 from torch import Tensor
 
 from lmdeploy.messages import EngineCoreEvent, EngineCoreEventType, GenerationConfig, LogitsProcessor
-from lmdeploy.pytorch.disagg.request import MigrationRequest
+from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest
 from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs
 from lmdeploy.utils import get_logger
diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py
index ddc740dcbc..c1967a0f07 100644
--- a/lmdeploy/pytorch/models/deepseek_v2.py
+++ b/lmdeploy/pytorch/models/deepseek_v2.py
@@ -723,6 +723,7 @@ def forward(self, hidden_states: torch.Tensor):
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         topk_weights, topk_ids = self.gate(hidden_states)
+
         out_states = self.experts(
             hidden_states,
             topk_weights,
diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py
index f25f2c07e6..13a4cecbf0 100644
--- a/lmdeploy/pytorch/paging/scheduler.py
+++ b/lmdeploy/pytorch/paging/scheduler.py
@@ -339,6 +339,9 @@ def has_running(self):
     def has_waiting(self):
         return self.num_waiting() > 0
 
+    def has_to_be_migrated(self):
+        return self.num_to_be_migrated() > 0
+
     def has_migration_running(self):
         return self.num_running() > 0
 
@@ -360,6 +363,14 @@ def num_waiting(self):
         """Num waiting."""
         return self.seq_manager.num_sequences(MessageStatus.WAITING)
 
+    def num_to_be_migrated(self):
+        """Num to be migrated."""
+        return self.seq_manager.num_sequences(MessageStatus.TO_BE_MIGRATED)
+
+    def num_migration_locked(self):
+        """Num migration locked."""
+        return self.seq_manager.num_sequences(MessageStatus.MIGRATION_LOCKED)
+
     def num_migration_running(self):
         """Num migration running."""
        return self.seq_manager.num_sequences(MessageStatus.RUNNING_MIGRATION)
diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py
index 4904aa93d1..ba4e27a603 100644
--- a/lmdeploy/serve/async_engine.py
+++ b/lmdeploy/serve/async_engine.py
@@ -24,7 +24,8 @@
 from lmdeploy.metrics.metrics_processor import metrics_processor
 from lmdeploy.metrics.stats import IterationStats, RequestState
 from lmdeploy.model import MODELS, BaseChatTemplate, ChatTemplateConfig, best_match_model
-from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest
+from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest,
+                                                   DistServeInitRequest)
 from lmdeploy.serve.utils import LogitsMixin
 from lmdeploy.tokenizer import DetokenizeState
 from lmdeploy.utils import _get_and_verify_max_len, _stop_words, get_hf_gen_cfg, get_logger
@@ -811,8 +812,7 @@ def is_error(status):
                     if outputs.last_hidden_state is not None:
                         out.last_hidden_state = outputs.last_hidden_state
                         if hit_stop_token:
-                            out.last_hidden_state = \
-                                out.last_hidden_state[:-hit_stop_token]
+                            out.last_hidden_state = out.last_hidden_state[:-hit_stop_token]
                     if outputs.logits is not None:
                         out.logits = outputs.logits
                         if hit_stop_token:
@@ -972,4 +972,7 @@ def p2p_initialize(self, init_request: DistServeInitRequest):
     def p2p_connect(self, conn_request: List[DistServeConnectionRequest]):
         return self.engine.p2p_connect(conn_request)
 
+    def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest):
+        return self.engine.p2p_drop_connect(drop_conn_request)
+
     """ DistServe Async Engine API End """
diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index f49b2319d8..205289f9d8 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -26,7 +26,9 @@
 from lmdeploy.metrics.metrics_processor import metrics_processor
 from lmdeploy.model import ChatTemplateConfig
 from lmdeploy.pytorch.disagg.config import DistServeEngineConfig
-from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest, MigrationRequest
+from lmdeploy.pytorch.disagg.conn.protocol import (DistServeCacheFreeRequest, DistServeConnectionRequest,
+                                                   DistServeDropConnectionRequest, DistServeInitRequest,
+                                                   MigrationRequest)
 from lmdeploy.serve.async_engine import AsyncEngine
 from lmdeploy.serve.openai.protocol import ChatCompletionResponse  # noqa: E501
 from lmdeploy.serve.openai.protocol import (ChatCompletionRequest, ChatCompletionResponseChoice,
@@ -963,14 +965,18 @@ async def p2p_initialize(init_request: DistServeInitRequest):
 
 @router.post('/distserve/p2p_connect')
-async def p2p_connect(conn_request: List[DistServeConnectionRequest]):
+async def p2p_connect(conn_request: DistServeConnectionRequest):
     return VariableInterface.async_engine.p2p_connect(conn_request)
 
 
+@router.post('/distserve/p2p_drop_connect')
+async def p2p_drop_connect(drop_conn_request: DistServeDropConnectionRequest):
+    return VariableInterface.async_engine.p2p_drop_connect(drop_conn_request)
+
+
 @router.post('/distserve/free_cache')
-async def free_cache(raw_request: Request) -> JSONResponse:
-    config = await raw_request.json()
-    session_id = int(config['session_id'])
+async def free_cache(cache_free_request: DistServeCacheFreeRequest) -> JSONResponse:
+    session_id = cache_free_request.remote_session_id
     VariableInterface.async_engine.free_cache(session_id)
     return {'status': 'SUCCESS'}
 
diff --git a/lmdeploy/serve/proxy/proxy.py
diff --git a/lmdeploy/serve/proxy/proxy.py b/lmdeploy/serve/proxy/proxy.py
index f357d04f74..a1f83dbe57 100644
--- a/lmdeploy/serve/proxy/proxy.py
+++ b/lmdeploy/serve/proxy/proxy.py
@@ -21,11 +21,10 @@
 from fastapi.responses import JSONResponse, StreamingResponse
 from pydantic import BaseModel, Field

-from lmdeploy.pytorch.disagg.config import (DistServeRDMAConfig, EngineRole, MigrationProtocol, RDMALinkType,
-                                            ServingStrategy)
-from lmdeploy.pytorch.disagg.conn import PDConnectionPool
+from lmdeploy.pytorch.disagg.config import DistServeRDMAConfig, EngineRole, RDMALinkType, ServingStrategy
+from lmdeploy.pytorch.disagg.conn.protocol import MigrationProtocol, MigrationRequest
+from lmdeploy.pytorch.disagg.conn.proxy_conn import PDConnectionPool
 from lmdeploy.pytorch.disagg.messages import PDConnectionMessage
-from lmdeploy.pytorch.disagg.request import MigrationRequest
 from lmdeploy.serve.openai.api_server import check_api_key, create_error_response
 from lmdeploy.serve.openai.protocol import ModelCard  # noqa: E501
 from lmdeploy.serve.openai.protocol import ChatCompletionRequest, CompletionRequest, ModelList, ModelPermission
@@ -109,7 +108,7 @@ def __init__(self,
         self.migration_protocol = MigrationProtocol[migration_protocol]
         self.rdma_config = DistServeRDMAConfig(with_gdr=with_gdr, link_type=RDMALinkType[link_type])
         self.pd_connection_pool = PDConnectionPool()
-        self.initialized = False
+        self.dummy_prefill = False

     def get_nodes(self, role: EngineRole) -> Dict:
         items = list(self.nodes.items())
@@ -174,12 +173,7 @@ def remove(self, node_url: str):
         if node_url in self.nodes.keys():
             self.nodes.pop(node_url)
         self.update_config_file()
-        dropped_conn = []
-        for conn in self.pd_connection_pool.pool:
-            if node_url in conn:
-                dropped_conn.append(conn)
-        for conn in dropped_conn:
-            self.pd_connection_pool.drop(*conn)
+        self.pd_connection_pool.dereg_instance(node_url)

     def terminate_node(self, node_url: str):
         """Terminate a node."""
@@ -343,12 +337,7 @@ def handle_api_timeout(self, node_url):
         }
         return json.dumps(ret).encode() + b'\n'

-    async def stream_generate(self,
-                              request: Dict,
-                              node_url: str,
-                              endpoint: str,
-                              prefill_url: Optional[str] = None,
-                              remote_session_id: int = None):
+    async def stream_generate(self, request: Dict, node_url: str, endpoint: str):
         """Return a generator to handle the input request.

         Args:
@@ -362,16 +351,12 @@ async def stream_generate(self,
                     async for line in response.content:
                         if line.strip():
                             yield line + b'\n\n'
-                    if prefill_url:
-                        async with session.post(f'{prefill_url}/distserve/free_cache',
-                                                json={'session_id': remote_session_id}) as response:
-                            await response.json()
             except (Exception, GeneratorExit, aiohttp.ClientError) as e:  # noqa
                 logger.error(f'catched an exception: {e}')
                 # exception happened, reduce unfinished num
                 yield self.handle_api_timeout(node_url)

-    async def generate(self, request: Dict, node_url: str, endpoint: str, is_prefill: bool = False):
+    async def generate(self, request: Dict, node_url: str, endpoint: str):
         """Return a the response of the input request.

         Args:
@@ -522,6 +507,12 @@ async def connection_warmup():
     return JSONResponse({'SUCCESS': True})


+@app.post('/distserve/gc')
+async def cache_block_gc_to_be_migrated():
+    # TODO (JimyMa): add garbage collection of to be migrated request
+    raise NotImplementedError
+
+
 @app.post('/v1/chat/completions', dependencies=[Depends(check_api_key)])
 async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Request = None):
     """Completion API similar to OpenAI's API.
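The hunks that follow gate every prefill-dependent step behind the new `node_manager.dummy_prefill` flag so that decode-side performance can be profiled without a live prefill engine. A condensed, illustrative view of that guard; `dispatch_prefill` is a made-up name, while `node_manager`, `EngineRole`, and the `'dummy:dummy'` placeholder come from the patch:

```python
# Illustrative condensation of the dummy-prefill guard; not a function in
# the codebase.
import json

from lmdeploy.pytorch.disagg.config import EngineRole


async def dispatch_prefill(node_manager, model: str, prefill_request_dict: dict, endpoint: str):
    if node_manager.dummy_prefill:
        # Skip the prefill leg entirely; downstream code must tolerate the
        # placeholder URL and the empty prefill_info.
        return 'dummy:dummy', {}
    p_url = node_manager.get_node_url(model, EngineRole.Prefill)
    start = node_manager.pre_call(p_url)
    prefill_info = json.loads(await node_manager.generate(prefill_request_dict, p_url, endpoint))
    node_manager.post_call(p_url, start)
    return p_url, prefill_info
```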
@@ -608,17 +599,17 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque
         prefill_request_dict['with_cache'] = True
         prefill_request_dict['preserve_cache'] = True

-        p_url = node_manager.get_node_url(request.model, EngineRole.Prefill)
-        if not p_url:
-            return node_manager.handle_unavailable_model(request.model)
-        logger.info(f'A Prefill request is dispatched to {p_url}')
+        prefill_info = {}
+        p_url = 'dummy:dummy'
+        if not node_manager.dummy_prefill:
+            p_url = node_manager.get_node_url(request.model, EngineRole.Prefill)
+            if not p_url:
+                return node_manager.handle_unavailable_model(request.model)
+            logger.info(f'A Prefill request is dispatched to {p_url}')

-        start = node_manager.pre_call(p_url)
-        prefill_info = json.loads(await node_manager.generate(prefill_request_dict,
-                                                              p_url,
-                                                              '/v1/chat/completions',
-                                                              is_prefill=True))
-        node_manager.post_call(p_url, start)
+            start = node_manager.pre_call(p_url)
+            prefill_info = json.loads(await node_manager.generate(prefill_request_dict, p_url, '/v1/chat/completions'))
+            node_manager.post_call(p_url, start)

         # # Decode
         d_url = node_manager.get_node_url(request.model, EngineRole.Decode)
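The next hunk also replaces the explicit `/distserve/free_cache` round-trip with pool-managed bookkeeping: the prefill session is shelved when the decode request is issued and unshelved once decode completes. A sketch of that pairing, assuming only the `shelf_prefill_session`/`unshelf_prefill_session` methods the patch calls; the wrapper itself (and its `finally`-based cleanup) is illustrative:

```python
# Hypothetical wrapper showing the shelf/unshelf pairing around a decode
# call; error handling and the streaming path are omitted.
async def decode_with_shelved_prefill(pool, p_url: str, d_url: str, session_id: int, run_decode):
    pool.shelf_prefill_session((p_url, d_url), session_id)
    try:
        return await run_decode()
    finally:
        pool.unshelf_prefill_session((p_url, d_url), session_id)
```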
@@ -626,42 +617,44 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque
             return node_manager.handle_unavailable_model(request.model)
         logger.info(f'A Decode request is dispatched to {d_url}')

-        if not node_manager.pd_connection_pool.is_connected(p_url, d_url):
-            await node_manager.pd_connection_pool.connect(
-                PDConnectionMessage(
-                    p_url=p_url,
-                    d_url=d_url,
-                    protocol=node_manager.migration_protocol,
-                    rdma_config=node_manager.rdma_config,
-                ))
+        if not node_manager.dummy_prefill:
+            if not node_manager.pd_connection_pool.is_connected(p_url, d_url):
+                await node_manager.pd_connection_pool.connect(
+                    PDConnectionMessage(
+                        p_url=p_url,
+                        d_url=d_url,
+                        protocol=node_manager.migration_protocol,
+                        rdma_config=node_manager.rdma_config,
+                    ))
+
+        remote_session_id = int(prefill_info.get('id')) if prefill_info.get('id') else 0
+        remote_block_ids = prefill_info.get('cache_block_ids') or []
+        remote_token_id = prefill_info.get('remote_token_ids')[-1] if prefill_info.get('remote_token_ids') else 0
+
         request_dict['migration_request'] = MigrationRequest(
             protocol=node_manager.migration_protocol,
             remote_engine_id=p_url,
-            remote_session_id=int(prefill_info['id']),
-            remote_block_ids=prefill_info['cache_block_ids'],
-            remote_token_id=prefill_info['remote_token_ids'][-1],
-        ).model_dump(mode='json')
+            remote_session_id=remote_session_id,
+            remote_block_ids=remote_block_ids,
+            remote_token_id=remote_token_id,
+            is_dummy_prefill=node_manager.dummy_prefill).model_dump(mode='json')

         start = node_manager.pre_call(d_url)
+        if not node_manager.dummy_prefill:
+            node_manager.pd_connection_pool.shelf_prefill_session((p_url, d_url), prefill_info['id'])
         if request.stream is True:
-            response = node_manager.stream_generate(request_dict,
-                                                    d_url,
-                                                    '/v1/chat/completions',
-                                                    prefill_url=p_url,
-                                                    remote_session_id=int(prefill_info['id']))
+            response = node_manager.stream_generate(request_dict, d_url, '/v1/chat/completions')
             background_task = node_manager.create_background_tasks(d_url, start)
-            return StreamingResponse(response, background=background_task)
+            resp = StreamingResponse(response, background=background_task)
         else:
-            try:
-                response = await node_manager.generate(request_dict, d_url, '/v1/chat/completions')
-                node_manager.post_call(d_url, start)
-                resp = JSONResponse(json.loads(response))
-            finally:
-                async with aiohttp.ClientSession() as session:
-                    async with session.post(f'{p_url}/distserve/free_cache', json={'session_id':
-                                                                                   prefill_info['id']}) as response:
-                        await response.json()
-            return resp
+            response = await node_manager.generate(request_dict, d_url, '/v1/chat/completions')
+            node_manager.post_call(d_url, start)
+            resp = JSONResponse(json.loads(response))
+
+        if not node_manager.dummy_prefill:
+            node_manager.pd_connection_pool.unshelf_prefill_session((p_url, d_url), prefill_info['id'])
+
+        return resp
+
     else:
         raise ValueError(f'No serving strategy named {node_manager.serving_strategy}')
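Because `prefill_info` is an empty dict under dummy prefill, the migration metadata is now derived with fallbacks instead of direct indexing; the same extraction recurs in the completions handler below. Pulled out for clarity, with a hypothetical helper name; the three expressions mirror the patch:

```python
# Hypothetical helper mirroring the fallback extraction used in both
# handlers; returns safe defaults when the prefill leg was skipped.
def extract_prefill_meta(prefill_info: dict):
    remote_session_id = int(prefill_info.get('id')) if prefill_info.get('id') else 0
    remote_block_ids = prefill_info.get('cache_block_ids') or []
    remote_token_id = (prefill_info.get('remote_token_ids')[-1]
                       if prefill_info.get('remote_token_ids') else 0)
    return remote_session_id, remote_block_ids, remote_token_id
```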
@@ -732,59 +725,74 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None
         prefill_request_dict['with_cache'] = True
         prefill_request_dict['preserve_cache'] = True

-        p_url = node_manager.get_node_url(request.model, EngineRole.Prefill)
-        if not p_url:
-            return node_manager.handle_unavailable_model(request.model)
-        logger.info(f'A Prefill request is dispatched to {p_url}')
+        if not node_manager.dummy_prefill:
+            try:
+                p_url = node_manager.get_node_url(request.model, EngineRole.Prefill)
+            except Exception as e:
+                logger.error(f'error Msg: {str(e)}')
+                return {'status': 'Instance sch error, cannot find available p_url'}
+
+            if not p_url:
+                return node_manager.handle_unavailable_model(request.model)
+            logger.info(f'A Prefill request is dispatched to {p_url}')
+
+            start = node_manager.pre_call(p_url)
+            prefill_info = json.loads(await node_manager.generate(prefill_request_dict, p_url, '/v1/completions'))
+            node_manager.post_call(p_url, start)
+        else:
+            p_url = 'dummy:dummy'
+            prefill_info = {}

-        start = node_manager.pre_call(p_url)
-        prefill_info = json.loads(await node_manager.generate(prefill_request_dict,
-                                                              p_url,
-                                                              '/v1/completions',
-                                                              is_prefill=True))
-        node_manager.post_call(p_url, start)
+        # Decode
+        try:
+            d_url = node_manager.get_node_url(request.model, EngineRole.Decode)
+        except Exception as e:
+            logger.error(f'error Msg: {str(e)}')
+            return {'status': 'Instance sch error, cannot find available d_url'}

-        # # Decode
-        d_url = node_manager.get_node_url(request.model, EngineRole.Decode)
         if not d_url:
             return node_manager.handle_unavailable_model(request.model)
         logger.info(f'A Decode request is dispatched to {d_url}')

-        if not node_manager.pd_connection_pool.is_connected(p_url, d_url):
-            await node_manager.pd_connection_pool.connect(
-                PDConnectionMessage(
-                    p_url=p_url,
-                    d_url=d_url,
-                    protocol=node_manager.migration_protocol,
-                    rdma_config=node_manager.rdma_config,
-                ))
-
+        if not node_manager.dummy_prefill:
+            if not node_manager.pd_connection_pool.is_connected(p_url, d_url):
+                try:
+                    await node_manager.pd_connection_pool.connect(
+                        PDConnectionMessage(
+                            p_url=p_url,
+                            d_url=d_url,
+                            protocol=node_manager.migration_protocol,
+                            rdma_config=node_manager.rdma_config,
+                        ))
+                except Exception as e:
+                    logger.error(f'error Msg: {str(e)}')
+                    return {'status': f'Connection error, cannot establish connection {(p_url, d_url)}'}
+            node_manager.pd_connection_pool.shelf_prefill_session((p_url, d_url), prefill_info['id'])
+
+        remote_session_id = int(prefill_info.get('id')) if prefill_info.get('id') else 0
+        remote_block_ids = prefill_info.get('cache_block_ids') or []
+        remote_token_id = prefill_info.get('remote_token_ids')[-1] if prefill_info.get('remote_token_ids') else 0

         request_dict['migration_request'] = MigrationRequest(
             protocol=node_manager.migration_protocol,
             remote_engine_id=p_url,
-            remote_session_id=int(prefill_info['id']),
-            remote_block_ids=prefill_info['cache_block_ids'],
-            remote_token_id=prefill_info['remote_token_ids'][-1],
-        ).model_dump(mode='json')
+            remote_session_id=remote_session_id,
+            remote_block_ids=remote_block_ids,
+            remote_token_id=remote_token_id,
+            is_dummy_prefill=node_manager.dummy_prefill).model_dump(mode='json')

         start = node_manager.pre_call(d_url)
         if request.stream is True:
-            response = node_manager.stream_generate(request_dict,
-                                                    d_url,
-                                                    '/v1/completions',
-                                                    prefill_url=p_url,
-                                                    remote_session_id=int(prefill_info['id']))
+            response = node_manager.stream_generate(request_dict, d_url, '/v1/completions')
             background_task = node_manager.create_background_tasks(d_url, start)
-            return StreamingResponse(response, background=background_task)
+            resp = StreamingResponse(response, background=background_task)
         else:
             response = await node_manager.generate(request_dict, d_url, '/v1/completions')
             node_manager.post_call(d_url, start)
             resp = JSONResponse(json.loads(response))
-            async with aiohttp.ClientSession() as session:
-                async with session.post(f'{p_url}/distserve/free_cache', json={'session_id':
-                                                                               prefill_info['id']}) as response:
-                    await response.json()
-        return resp
+
+        if not node_manager.dummy_prefill:
+            node_manager.pd_connection_pool.unshelf_prefill_session((p_url, d_url), prefill_info.get('id'))
+
+        return resp
     else:
         raise ValueError(f'No serving strategy named {node_manager.serving_strategy}')
@@ -799,6 +807,7 @@ def proxy(server_name: str = '0.0.0.0',
           disable_cache_status: bool = False,
           link_type: Literal['RoCE', 'IB'] = 'RoCE',
           migration_protocol: Literal['RDMA'] = 'RDMA',
+          dummy_prefill: bool = False,
           **kwargs):
     """To launch the proxy server.

@@ -821,6 +830,7 @@ def proxy(server_name: str = '0.0.0.0',
     node_manager.serving_strategy = ServingStrategy[serving_strategy]
     node_manager.routing_strategy = RoutingStrategy.from_str(routing_strategy)
     node_manager.migration_protocol = MigrationProtocol[migration_protocol]
+    node_manager.dummy_prefill = dummy_prefill
     node_manager.rdma_config = DistServeRDMAConfig(
         link_type=RDMALinkType[link_type],
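Taken together, the new keyword threads straight through to the node manager. A minimal launch sketch using the `proxy` entrypoint as modified above; the strategy value assumes a DistServe deployment, and all other keywords are left at their defaults:

```python
# Minimal sketch: start the proxy in dummy-prefill profiling mode.
# `serving_strategy='DistServe'` is an assumed deployment choice.
from lmdeploy.serve.proxy.proxy import proxy

proxy(server_name='0.0.0.0',
      serving_strategy='DistServe',
      dummy_prefill=True)
```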