fix free cache in MPEngine branch #3670


Open · wants to merge 16 commits into main
6 changes: 3 additions & 3 deletions benchmark/profile_throughput.py
@@ -14,7 +14,6 @@
 from lmdeploy.cli.utils import ArgumentHelper, DefaultsAndTypesHelpFormatter
 from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig
 from lmdeploy.profiler import Profiler, Session
-from lmdeploy.pytorch.engine import EngineInstance
 from lmdeploy.tokenizer import DetokenizeState, Tokenizer
 from lmdeploy.utils import get_logger
 
@@ -142,7 +141,8 @@ def __init__(self, model_path: str, engine_config: Union[PytorchEngineConfig, Tu
             tm_model = TurboMind.from_pretrained(model_path, tokenizer=self.tokenizer, engine_config=engine_config)
         elif isinstance(engine_config, PytorchEngineConfig):
             from lmdeploy.pytorch.engine import Engine as PytorchEngine
-            tm_model = PytorchEngine(model_path, tokenizer=self.tokenizer, engine_config=engine_config)
+            tm_model = PytorchEngine.from_pretrained(model_path, tokenizer=self.tokenizer, engine_config=engine_config)
 
         self.tm_model = tm_model
         self.pbar = None
 
@@ -190,7 +190,7 @@ async def _inference(self, req_queue: Queue, session_id: int, temperature: float
             await generator.aclose()
 
             # for pytorch engine to restart a session
-            if isinstance(model_inst, EngineInstance):
+            if hasattr(model_inst, '_is_pytorch_engine'):
                 await model_inst.async_end(session_id)
 
             self.pbar.update(1)
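
With the multi-process engine, the benchmark can no longer rely on an isinstance check: the per-session instance it holds may be a cross-process proxy rather than a pytorch EngineInstance, so the code above duck-types on a marker attribute instead. A minimal sketch of the contract this assumes (the class name is illustrative, not lmdeploy's):

class MPEngineInstanceProxy:
    # Any object that should receive async_end() exposes this marker, so the
    # check works on both the in-process instance and an MP-engine proxy.
    _is_pytorch_engine = True

    async def async_end(self, session_id: int):
        ...  # forward the session-end call to the engine process
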
4 changes: 3 additions & 1 deletion lmdeploy/messages.py
@@ -7,7 +7,7 @@
 from pydantic.dataclasses import dataclass as pydantic_dataclass
 
 from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend
-from lmdeploy.pytorch.disagg.request import MigrationRequest
+from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest
 
 from .tokenizer import Tokenizer
 from .utils import get_logger
@@ -318,6 +318,7 @@ class PytorchEngineConfig:
             'Decode']. Default to `EngineRole.Hybrid`.
         migration_backend: migration backend. options: ['DLSlime'].
             Default to `MigrationBackend.DLSlime`.
+        enable_mp_engine (bool): run engine in multi-process mode.
         model_format (str): weight quantization policy, options: ['fp8'].
     """
     dtype: str = 'auto'
@@ -346,6 +347,7 @@
     empty_init: bool = False
     enable_microbatch: bool = False
     enable_eplb: bool = False
+    enable_mp_engine: bool = False
     model_format: str = None
 
     role: EngineRole = EngineRole.Hybrid
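
For reference, a usage sketch of the new flag through lmdeploy's pipeline API (the model path is illustrative; the flag's exact behavior is defined by the MP engine introduced in this PR):

from lmdeploy import pipeline, PytorchEngineConfig

# Run the PyTorch engine in a separate process and talk to it through a
# proxy instead of hosting the engine in the caller's process.
backend_config = PytorchEngineConfig(enable_mp_engine=True)
pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=backend_config)
print(pipe(['Hello, world!']))
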
17 changes: 6 additions & 11 deletions lmdeploy/pytorch/disagg/backend/__init__.py
@@ -1,24 +1,19 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from lmdeploy.logger import get_logger
 
-logger = get_logger('lmdeploy')
+logger = get_logger("lmdeploy")
 
 try:
-    logger.debug('Registering DLSlime Backend')
+    logger.debug("Registering DLSlime Backend")
     from .dlslime import DLSlimeBackend
 except ImportError:
-    logger.warning('Disable DLSlime Backend')
+    logger.warning("Disable DLSlime Backend")
 
 try:
-    logger.debug('Registering Mooncake Backend')
+    logger.debug("Registering Mooncake Backend")
     from .mooncake import MooncakeBackend
 except ImportError:
-    logger.warning('Disable Mooncake Backend')
+    logger.warning("Disable Mooncake Backend")
 
-try:
-    logger.debug('Registering InfiniStoreBackend Backend')
-    from .infinistore import InfiniStoreBackend
-except ImportError:
-    logger.warning('Disable InfiniStoreBackend Backend')
 
-__all__ = ['DLSlimeBackend', 'MooncakeBackend', 'InfiniStoreBackend']
+__all__ = ["DLSlimeBackend", "MooncakeBackend"]
6 changes: 3 additions & 3 deletions lmdeploy/pytorch/disagg/backend/base.py
@@ -1,9 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from abc import abstractmethod
 
-from lmdeploy.pytorch.disagg.config import MigrationProtocol
+from lmdeploy.pytorch.disagg.conn.protocol import MigrationProtocol
+from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo
 from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment
-from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest
 
 
 class MigrationBackendImpl:
@@ -21,7 +21,7 @@ def endpoint_info(self, remote_engine_id: int, protocol: MigrationProtocol):
         return NotImplementedError
 
     @abstractmethod
-    def p2p_connect(self, conn_req: DistServeConnectionRequest):
+    def p2p_connect(self, remote_engine_id: str, conn_req: DistServeKVTransferEndpointInfo):
         raise NotImplementedError
 
     @abstractmethod
12 changes: 6 additions & 6 deletions lmdeploy/pytorch/disagg/backend/dlslime.py
@@ -10,9 +10,9 @@
 from lmdeploy.logger import get_logger
 from lmdeploy.pytorch.disagg.backend.backend import MIGRATION_BACKENDS
 from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl
-from lmdeploy.pytorch.disagg.config import DistServeEngineConfig, MigrationBackend, MigrationProtocol
+from lmdeploy.pytorch.disagg.config import DistServeEngineConfig, MigrationBackend
+from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo, MigrationProtocol
 from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment
-from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest
 
 logger = get_logger('lmdeploy')
 
@@ -60,8 +60,8 @@ def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage
                                                                            register_mr_request.offset,
                                                                            register_mr_request.length)
 
-    def connect(self, connect_request: DistServeConnectionRequest):
-        self.endpoint[connect_request.protocol].connect(json.loads(connect_request.remote_endpoint_info))
+    def connect(self, kvtransfer_endpoint_info: DistServeKVTransferEndpointInfo):
+        self.endpoint[kvtransfer_endpoint_info.protocol].connect(json.loads(kvtransfer_endpoint_info.endpoint_info))
 
     async def p2p_migrate(self, assignment: MigrationAssignment, async_op=False):
         batch = [
@@ -104,8 +104,8 @@ def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage
     def endpoint_info(self, remote_engine_id: int, protocol: MigrationProtocol):
         return self.links[remote_engine_id].endpoint[protocol].endpoint_info
 
-    def p2p_connect(self, conn_req: DistServeConnectionRequest):
-        self.links[conn_req.remote_engine_id].connect(conn_req)
+    def p2p_connect(self, remote_engine_id: str, conn_req: DistServeKVTransferEndpointInfo):
+        self.links[remote_engine_id].connect(conn_req)
 
     async def p2p_migrate(self, assignment: MigrationAssignment, async_op: bool = False):
         await self.links[assignment.remote_engine_id].p2p_migrate(assignment, async_op=async_op)
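
The reworked p2p_connect separates the routing key from the payload: instead of one DistServeConnectionRequest bundling both, the backend now receives the remote_engine_id explicitly plus a DistServeKVTransferEndpointInfo. A sketch of the resulting call-site shape (identifier values are illustrative, not from this PR):

# before: backend.p2p_connect(conn_req)   # engine id buried inside conn_req
# after:  routing key and endpoint payload are explicit
backend.p2p_connect(remote_engine_id=remote_engine_id,
                    conn_req=kvtransfer_endpoint_info)
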
31 changes: 0 additions & 31 deletions lmdeploy/pytorch/disagg/backend/infinistore.py

This file was deleted.

6 changes: 3 additions & 3 deletions lmdeploy/pytorch/disagg/backend/mooncake.py
@@ -1,9 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from lmdeploy.pytorch.disagg.backend.backend import MIGRATION_BACKENDS
 from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl
-from lmdeploy.pytorch.disagg.config import MigrationBackend, MigrationProtocol
+from lmdeploy.pytorch.disagg.config import MigrationBackend
+from lmdeploy.pytorch.disagg.conn.protocol import DistServeInitRequest, DistServeKVTransferEndpointInfo, MigrationProtocol
 from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment
-from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest
 
 
 @MIGRATION_BACKENDS.register_module(MigrationBackend.Mooncake.name)
@@ -18,7 +18,7 @@ def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage
     def endpoint_info(self, remote_engine_id: int, protocol: MigrationProtocol):
         return NotImplementedError
 
-    def p2p_connect(self, connect_request: DistServeConnectionRequest):
+    def p2p_connect(self, remote_engine_id: str, conn_req: DistServeKVTransferEndpointInfo):
         raise NotImplementedError
 
     def p2p_migrate(self, assignment: MigrationAssignment, async_op: bool = False):
26 changes: 0 additions & 26 deletions lmdeploy/pytorch/disagg/config.py
@@ -45,23 +45,6 @@ class MigrationBackend(enum.Enum):
     InfiniStore = enum.auto()
 
 
-class MigrationProtocol(enum.Enum):
-    """Migration Transport Protocol.
-
-    Attributes:
-        TCP: TCP for General Purpose Transport Protocol.
-        RDMA: IB or RoCEv1/v2.
-        NVLINK: High device-to-device link.
-
-    Warning: By now, only `GPU Directed RDMA` is supported in DistServe.
-    We preserve several protocol and will be implemented in the future.
-    """
-
-    TCP = enum.auto()
-    RDMA = enum.auto()
-    NVLINK = enum.auto()
-
-
 class RDMALinkType(enum.Enum):
     """RDMA Link Type."""
 
@@ -126,12 +109,3 @@ class DistServeEngineConfig(BaseModel):
     num_cpu_blocks: int
     num_gpu_blocks: int
 
-
-class DistServeConfig(BaseModel):
-    """DistServe Config."""
-
-    serving_strategy: ServingStrategy
-    distserve_transport_protocol: MigrationProtocol
-    rdma_config: Optional[DistServeRDMAConfig] = None
-    nvlink_config: Optional[DistServeNVLinkConfig] = None
-    tcp_config: Optional[DistServeTCPConfig] = None
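
MigrationProtocol itself is not dropped: every import touched by this PR now resolves it from lmdeploy.pytorch.disagg.conn.protocol, a module not shown in this diff. Judging from the deleted definition above, the relocated enum presumably still reads:

import enum


class MigrationProtocol(enum.Enum):
    """Migration transport protocol, relocated to conn/protocol.py."""

    TCP = enum.auto()
    RDMA = enum.auto()
    NVLINK = enum.auto()
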
116 changes: 116 additions & 0 deletions lmdeploy/pytorch/disagg/conn/engine_conn.py
@@ -0,0 +1,116 @@
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import os
import time
from typing import TYPE_CHECKING, Dict, List
from urllib.parse import urlparse

import zmq
import zmq.asyncio

from lmdeploy.logger import get_logger
from lmdeploy.pytorch.disagg.conn.protocol import (
    DistServeCacheFreeRequest,
    DistServeConnectionRequest,
    DistServeConnectionResponse,
    DistServeConnectionStatus,
    DistServeEngineEndpointInfo,
    DistServeInitRequest,
    DistServeInitResponse,
    DistServeKVTransferEndpointInfo,
)
from lmdeploy.pytorch.engine.executor.dist_utils import find_available_port
from lmdeploy.pytorch.messages import MessageStatus

if TYPE_CHECKING:
    from lmdeploy.pytorch.engine.engine import Engine


logger = get_logger("lmdeploy")


class EngineP2PConnection:

    def __init__(self, engine: "Engine"):
        self.engine: Engine = engine
        self.p2p_conn_ctx: Dict[str, zmq.asyncio.Context] = {}
        self.p2p_sender: Dict[str, zmq.asyncio.Socket] = {}
        self.p2p_receiver: Dict[str, zmq.asyncio.Socket] = {}

        # Environment variables are strings, so parse explicitly instead of
        # relying on the truthiness of the raw value.
        self.use_unique_kvtransfer_engine = os.environ.get(
            "LMDEPLOY_USE_UNIQUE_KVTRANSFER_ENGINE", "0").lower() in ("1", "true")

    async def p2p_initialize(self, init_request: DistServeInitRequest):
        ctx = zmq.asyncio.Context(2)
        sender = ctx.socket(zmq.PUSH)
        sender_port = find_available_port()
        sender_hostname = urlparse(init_request.local_engine_id).hostname
        zmq_address = f"tcp://{sender_hostname}:{sender_port}"
        sender.bind(zmq_address)
        receiver = ctx.socket(zmq.PULL)

        self.p2p_conn_ctx[init_request.remote_engine_id] = ctx
        self.p2p_sender[init_request.remote_engine_id] = sender
        self.p2p_receiver[init_request.remote_engine_id] = receiver

        kvtransfer_endpoint_info: List[DistServeKVTransferEndpointInfo] = \
            self.engine.executor.p2p_initialize(init_request)

        return DistServeInitResponse(
            engine_endpoint_info=DistServeEngineEndpointInfo(zmq_address=zmq_address),
            kvtransfer_endpoint_info=kvtransfer_endpoint_info,
            status=DistServeConnectionStatus.SUCCESS,
        )

    def p2p_connect(self, conn_request: DistServeConnectionRequest):
        self.p2p_receiver[conn_request.remote_engine_id].connect(conn_request.remote_engine_endpoint_info.zmq_address)
        self.engine.executor.p2p_connect(
            remote_engine_id=conn_request.remote_engine_id,
            conn_request=conn_request.remote_kvtransfer_endpoint_info,
        )
        event_loop = asyncio.get_event_loop()
        event_loop.create_task(self.handle_zmq_recv(conn_request.remote_engine_id))
        return DistServeConnectionResponse(status=DistServeConnectionStatus.SUCCESS)

    async def zmq_send(self, remote_engine_id: str, remote_session_id: int):
        await self.p2p_sender[remote_engine_id].send_pyobj(
            DistServeCacheFreeRequest(remote_engine_id=remote_engine_id, remote_session_id=remote_session_id))

    async def handle_zmq_recv(self, remote_engine_id: str):
        while True:
            req: DistServeCacheFreeRequest = await self.p2p_receiver[remote_engine_id].recv_pyobj()
            if isinstance(req, DistServeCacheFreeRequest):
                session_id = req.remote_session_id
                if session_id in self.engine.scheduler.sessions:
                    self.engine.scheduler.end_session(session_id=session_id)
                    logger.error({
                        'scheduling type': 'free',
                        'time': time.time(),
                        'dp_rank': self.engine.engine_config.dp_rank,
                        'role': self.engine.scheduler.cache_config.role,
                        'max batches': self.engine.scheduler.scheduler_config.max_batches,
                        'total_waiting': self.engine.scheduler.num_waiting(),
                        'total_running': self.engine.scheduler.num_running(),
                        'total_locking': self.engine.scheduler.num_locked(),
                        'total_to_be_migrated': self.engine.scheduler.num_to_be_migrated(),
                        'total_migration_waiting': self.engine.scheduler.num_migration_waiting(),
                        'total_migration_running': self.engine.scheduler.num_migration_running(),
                        'total_migration_locked': self.engine.scheduler.num_migration_locked(),
                        'total_stopped': self.engine.scheduler.seq_manager.num_sequences(MessageStatus.STOPPED),
                        'kv_usage': (
                            self.engine.scheduler.block_manager.get_num_free_gpu_blocks(),
                            self.engine.scheduler.block_manager.num_gpu_blocks,
                        ),
                    })
                else:
                    logger.error(f"invalid free, {remote_engine_id}, {session_id}")
            else:
                raise ValueError(f"Unsupported zmq request {type(req)}")

    async def p2p_disconnect(self, remote_engine_id: str):
        self.p2p_receiver[remote_engine_id].close()
        self.p2p_sender[remote_engine_id].close()
        self.p2p_conn_ctx[remote_engine_id].term()
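
Taken together, EngineP2PConnection adds a lightweight ZMQ control channel beside the KV-transfer links: p2p_initialize() binds a PUSH socket and reports its address, p2p_connect() attaches the PULL socket to the peer and starts the receive loop, and zmq_send() pushes a DistServeCacheFreeRequest so the peer frees a migrated session's KV cache in handle_zmq_recv(). A sketch of the expected call order between two engines (the driver code around the calls is illustrative; only the method names come from this file):

# Engine B answers engine A's init request: bind PUSH, report its address.
init_resp = await conn_b.p2p_initialize(init_request)

# Engine B wires the reverse direction once A's endpoint info arrives:
# connect PULL, link the KV-transfer endpoints, spawn the recv-loop task.
conn_b.p2p_connect(conn_request)

# When a migrated session finishes, tell the peer to free its KV blocks.
await conn_b.zmq_send(remote_engine_id=engine_a_id, remote_session_id=session_id)

# On shutdown, close the sockets and terminate the ZMQ context.
await conn_b.p2p_disconnect(engine_a_id)
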
6 changes: 6 additions & 0 deletions lmdeploy/pytorch/disagg/conn/kvtransfer.py
@@ -0,0 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.


class KVTransferEngine:
    pass
