From c1f3a6c5dac78056013fdff27d0f3dc745c019b8 Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 22 Sep 2025 21:45:53 +0800 Subject: [PATCH 01/15] WIP --- lmdeploy/pytorch/backends/awq_modules.py | 7 +- .../pytorch/backends/blockedf8_modules.py | 2 + lmdeploy/pytorch/backends/cuda/attention.py | 2 +- lmdeploy/pytorch/backends/cuda/awq_modules.py | 5 +- .../backends/cuda/blockedf8_modules.py | 9 +- .../pytorch/backends/cuda/graph_runner.py | 2 +- lmdeploy/pytorch/backends/cuda/qmodules.py | 5 +- .../pytorch/backends/default/awq_modules.py | 5 +- lmdeploy/pytorch/backends/default/linear.py | 15 +- .../pytorch/backends/dlinfer/awq_modules.py | 3 +- lmdeploy/pytorch/backends/dlinfer/linear.py | 5 +- lmdeploy/pytorch/backends/dlinfer/qmodules.py | 5 +- lmdeploy/pytorch/backends/linear.py | 2 + lmdeploy/pytorch/backends/qmodules.py | 3 +- lmdeploy/pytorch/config.py | 70 +++- lmdeploy/pytorch/distributed.py | 342 ++++++++++++------ lmdeploy/pytorch/engine/executor/base.py | 3 +- .../pytorch/engine/executor/ray_executor.py | 25 +- lmdeploy/pytorch/engine/model_agent.py | 60 +-- lmdeploy/pytorch/model_inputs.py | 35 +- lmdeploy/pytorch/models/deepseek_v2.py | 28 +- lmdeploy/pytorch/models/llama4.py | 4 +- lmdeploy/pytorch/models/qwen3_moe.py | 6 +- lmdeploy/pytorch/models/sdar_moe.py | 5 +- lmdeploy/pytorch/nn/attention.py | 8 +- lmdeploy/pytorch/nn/linear/__init__.py | 175 ++++----- lmdeploy/pytorch/nn/linear/awq.py | 45 ++- lmdeploy/pytorch/nn/linear/base.py | 61 +++- lmdeploy/pytorch/nn/linear/blocked_fp8.py | 48 ++- lmdeploy/pytorch/nn/linear/default.py | 55 ++- lmdeploy/pytorch/nn/linear/utils.py | 22 +- lmdeploy/pytorch/nn/linear/w8a8.py | 40 +- lmdeploy/pytorch/nn/moe.py | 32 +- lmdeploy/pytorch/nn/norm.py | 4 +- 34 files changed, 684 insertions(+), 454 deletions(-) diff --git a/lmdeploy/pytorch/backends/awq_modules.py b/lmdeploy/pytorch/backends/awq_modules.py index bfd5372ba3..1a9815c423 100644 --- a/lmdeploy/pytorch/backends/awq_modules.py +++ b/lmdeploy/pytorch/backends/awq_modules.py @@ -17,7 +17,12 @@ def update_weights(self, return qweight, scales, qzeros, bias @abstractmethod - def forward(self, x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False): + def forward(self, + x, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + all_reduce: bool = False, + group: Optional[torch.distributed.ProcessGroup] = None): """forward.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/blockedf8_modules.py b/lmdeploy/pytorch/backends/blockedf8_modules.py index 6750654a79..b30bb56fae 100644 --- a/lmdeploy/pytorch/backends/blockedf8_modules.py +++ b/lmdeploy/pytorch/backends/blockedf8_modules.py @@ -3,6 +3,7 @@ from typing import List, Optional import torch +import torch.distributed as dist class LinearBlockedF8Impl(ABC): @@ -19,6 +20,7 @@ def forward(self, scale: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False, + group: Optional[dist.ProcessGroup] = None, rank: int = 0, scatter_size: List[int] = None): """forward.""" diff --git a/lmdeploy/pytorch/backends/cuda/attention.py b/lmdeploy/pytorch/backends/cuda/attention.py index b241c384b2..b896b0579e 100644 --- a/lmdeploy/pytorch/backends/cuda/attention.py +++ b/lmdeploy/pytorch/backends/cuda/attention.py @@ -90,7 +90,7 @@ def __init__( self.flash_attention_fwd = flash_attention_fwd # for alibi attention - world_size, rank = get_tp_world_rank() + world_size, rank = get_tp_world_rank('attn') self.alibi_head_offset = self.num_heads * rank self.alibi_num_heads = 
self.num_heads * world_size self.block_sparse_size = block_sparse_size diff --git a/lmdeploy/pytorch/backends/cuda/awq_modules.py b/lmdeploy/pytorch/backends/cuda/awq_modules.py index be23e06e79..01516a8aca 100644 --- a/lmdeploy/pytorch/backends/cuda/awq_modules.py +++ b/lmdeploy/pytorch/backends/cuda/awq_modules.py @@ -55,12 +55,13 @@ def forward(self, scales: torch.Tensor, qzeros: torch.Tensor, bias: Optional[torch.Tensor] = None, - all_reduce: bool = False): + all_reduce: bool = False, + group: Optional[torch.distributed.ProcessGroup] = None): """forward.""" out_features = scales.size(1) out = wq_gemm_forward(x, qweight, qzeros, scales, self.w_bit, self.group_size, bias, out_features) if all_reduce: - dist.all_reduce(out) + dist.all_reduce(out, group=group) return out diff --git a/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py b/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py index 25d69806b3..bb70a9f718 100644 --- a/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py +++ b/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py @@ -13,12 +13,12 @@ logger = get_logger('lmdeploy') -def _reduce_scatter_input(out: torch.Tensor, rank: int, tp_sizes: List[int]): +def _reduce_scatter_input(out: torch.Tensor, rank: int, tp_sizes: List[int], group: Optional[dist.ProcessGroup] = None): """Reduce scatter.""" outs = out.split(tp_sizes, -2) out = outs[rank] outs = list(outs) - dist.reduce_scatter(out, outs) + dist.reduce_scatter(out, outs, group=group) return out @@ -117,6 +117,7 @@ def forward(self, scale: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False, + group: Optional[dist.ProcessGroup] = None, rank: int = 0, scatter_size: List[int] = None): """forward.""" @@ -131,9 +132,9 @@ def forward(self, if all_reduce: if scatter_size is not None: - out = _reduce_scatter_input(out, rank, scatter_size) + out = _reduce_scatter_input(out, rank, scatter_size, group=group) else: - dist.all_reduce(out) + dist.all_reduce(out, group=group) out = out.unflatten(0, x_shape[:-1]) return out diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index deb6c66bfd..536ab65e00 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -258,7 +258,7 @@ def update_inputs(self, inputs): meta = self.get_meta() padding_batch_size = meta.padding_batch_size tp_size = self._get_capture_tokens(padding_batch_size) - dp_meta.tp_sizes = [tp_size] * len(dp_meta.tp_sizes) + dp_meta.sync_tp_size(tp_size) return inputs def get_capture_batch_sizes(self) -> List[int]: diff --git a/lmdeploy/pytorch/backends/cuda/qmodules.py b/lmdeploy/pytorch/backends/cuda/qmodules.py index 9f2576c3f8..dc61787731 100644 --- a/lmdeploy/pytorch/backends/cuda/qmodules.py +++ b/lmdeploy/pytorch/backends/cuda/qmodules.py @@ -63,7 +63,8 @@ def forward(self, weight: torch.Tensor, scale: torch.Tensor, bias: Optional[torch.Tensor] = None, - all_reduce: bool = False): + all_reduce: bool = False, + group: Optional[torch.distributed.ProcessGroup] = None): """forward.""" if isinstance(x, torch.Tensor): input_quant, input_scale = per_token_quant_int8(x, 1e-7, quant_dtype=self.quant_dtype) @@ -79,7 +80,7 @@ def forward(self, bias=bias) if all_reduce: - dist.all_reduce(out) + dist.all_reduce(out, group=group) return out diff --git a/lmdeploy/pytorch/backends/default/awq_modules.py b/lmdeploy/pytorch/backends/default/awq_modules.py index 5c80e0e327..d2253920fa 100644 --- a/lmdeploy/pytorch/backends/default/awq_modules.py +++ 
b/lmdeploy/pytorch/backends/default/awq_modules.py @@ -62,7 +62,8 @@ def forward(self, scales: torch.Tensor, qzeros: torch.Tensor, bias: Optional[torch.Tensor] = None, - all_reduce: bool = False): + all_reduce: bool = False, + group: Optional[torch.distributed.ProcessGroup] = None): """forward.""" out_shape = x.shape[:-1] + (self.out_features, ) input_dtype = x.dtype @@ -77,7 +78,7 @@ def forward(self, if input_dtype != torch.float16: out = out.to(dtype=input_dtype) if all_reduce: - dist.all_reduce(out) + dist.all_reduce(out, group=group) return out diff --git a/lmdeploy/pytorch/backends/default/linear.py b/lmdeploy/pytorch/backends/default/linear.py index b223b498ae..6e4870b0e5 100644 --- a/lmdeploy/pytorch/backends/default/linear.py +++ b/lmdeploy/pytorch/backends/default/linear.py @@ -2,22 +2,20 @@ from typing import List, Optional import torch +import torch.distributed as dist import torch.nn.functional as F -import lmdeploy.pytorch.distributed as dist - from ..linear import LinearBuilder, LinearImpl -def _reduce_scatter_input(out: torch.Tensor, rank: int, tp_sizes: List[int]): +def _reduce_scatter_input(out: torch.Tensor, rank: int, tp_sizes: List[int], group: dist.ProcessGroup = None): """Reduce scatter.""" out = out.transpose(0, -2) - if not out.is_contiguous(): - out = out.contiguous() + out = out.contiguous() outs = out.split(tp_sizes, 0) out = outs[rank] outs = list(outs) - dist.reduce_scatter(out, outs) + dist.reduce_scatter(out, outs, group=group) out = out.transpose(0, -2) return out @@ -30,15 +28,16 @@ def forward(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False, + group: dist.ProcessGroup = None, rank: int = 0, scatter_size: List[int] = None): """forward.""" out = F.linear(x, weight, bias) if all_reduce: if scatter_size is not None: - out = _reduce_scatter_input(out, rank, scatter_size) + out = _reduce_scatter_input(out, rank, scatter_size, group=group) else: - dist.all_reduce(out) + dist.all_reduce(out, group=group) return out diff --git a/lmdeploy/pytorch/backends/dlinfer/awq_modules.py b/lmdeploy/pytorch/backends/dlinfer/awq_modules.py index 8dcf750478..1ec8bf0072 100644 --- a/lmdeploy/pytorch/backends/dlinfer/awq_modules.py +++ b/lmdeploy/pytorch/backends/dlinfer/awq_modules.py @@ -23,7 +23,8 @@ def forward(self, scales: torch.Tensor, qzeros: torch.Tensor, bias: Optional[torch.Tensor] = None, - all_reduce: bool = False): + all_reduce: bool = False, + group: Optional[torch.distributed.ProcessGroup] = None): """forward.""" out = awq_linear(x, qweight, scales, qzeros, bias, all_reduce, self.group_size) return out diff --git a/lmdeploy/pytorch/backends/dlinfer/linear.py b/lmdeploy/pytorch/backends/dlinfer/linear.py index 327eeae56b..ec682bba8b 100644 --- a/lmdeploy/pytorch/backends/dlinfer/linear.py +++ b/lmdeploy/pytorch/backends/dlinfer/linear.py @@ -3,8 +3,8 @@ from typing import List, Optional import torch +import torch.distributed as dist -import lmdeploy.pytorch.distributed as dist from lmdeploy.pytorch.kernels.dlinfer import linear from ..linear import LinearBuilder, LinearImpl @@ -32,12 +32,13 @@ def forward(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False, + group: dist.ProcessGroup = None, rank: int = 0, scatter_size: List[int] = None): """forward.""" out = linear(x, weight, bias, False) if all_reduce: - dist.all_reduce(out) + dist.all_reduce(out, group=group) return out diff --git a/lmdeploy/pytorch/backends/dlinfer/qmodules.py b/lmdeploy/pytorch/backends/dlinfer/qmodules.py index 
bdd5259fd0..fe52dd5f35 100644 --- a/lmdeploy/pytorch/backends/dlinfer/qmodules.py +++ b/lmdeploy/pytorch/backends/dlinfer/qmodules.py @@ -36,7 +36,8 @@ def forward(self, weight: torch.Tensor, scale: torch.Tensor, bias: Optional[torch.Tensor] = None, - all_reduce: bool = False): + all_reduce: bool = False, + group: Optional[torch.distributed.ProcessGroup] = None): """forward.""" if isinstance(x, torch.Tensor): input_quant, input_scale = dynamic_quant(x, self.quant_dtype) @@ -46,7 +47,7 @@ def forward(self, out = linear_w8a8(input_quant, weight, input_scale, scale, self.out_dtype, self.quant_dtype, bias) if all_reduce: - dist.all_reduce(out) + dist.all_reduce(out, group=group) return out diff --git a/lmdeploy/pytorch/backends/linear.py b/lmdeploy/pytorch/backends/linear.py index d577dcba54..740b4b7ecc 100644 --- a/lmdeploy/pytorch/backends/linear.py +++ b/lmdeploy/pytorch/backends/linear.py @@ -3,6 +3,7 @@ from typing import List, Optional import torch +import torch.distributed as dist class LinearImpl(ABC): @@ -18,6 +19,7 @@ def forward(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False, + group: dist.ProcessGroup = None, rank: int = 0, scatter_size: List[int] = None): """forward.""" diff --git a/lmdeploy/pytorch/backends/qmodules.py b/lmdeploy/pytorch/backends/qmodules.py index 7cb485888b..7173fb5f34 100644 --- a/lmdeploy/pytorch/backends/qmodules.py +++ b/lmdeploy/pytorch/backends/qmodules.py @@ -47,7 +47,8 @@ def forward(self, weight: torch.Tensor, scale: torch.Tensor, bias: Optional[torch.Tensor] = None, - all_reduce: bool = False): + all_reduce: bool = False, + group: Optional[torch.distributed.ProcessGroup] = None): """forward.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index ac3459e045..7e73ed7fb8 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -100,35 +100,70 @@ def __post_init__(self): self.enable_prefix_caching = False +class TPMode(enum.Enum): + """TP Mode.""" + DEFAULT = enum.auto() + DP_TP = enum.auto() + + @dataclass class DistConfig: dp: int = 1 - tp: int = 1 ep: int = 1 dp_rank: int = 0 enable_microbatch: bool = False enable_eplb: bool = False - world_size: int = None - attn_config: 'DistConfig' = None + world_size: int = 1 + + # tp + tp: int = 1 # default tp, equal to attn_tp + attn_tp: int = 1 # tp for attention + mlp_tp: int = 1 # tp for mlp + moe_tp: int = 1 # tp for moe + + # tp mode + mlp_tp_mode: TPMode = TPMode.DEFAULT + moe_tp_mode: TPMode = TPMode.DEFAULT def __post_init__(self): """Post init.""" assert self.dp_rank < self.dp assert self.dp >= 1 - if self.dp == 1: - world_size = max(self.tp, self.ep) - attn_config = self - else: - world_size = self.dp - attn_config = DistConfig(dp=1, tp=1, ep=1, dp_rank=0) + + dp = self.dp + tp = self.tp + ep = self.ep + + # world_size + world_size = ep if ep > 1 else tp + assert world_size >= dp and world_size % dp == 0 + assert world_size >= tp and world_size % tp == 0 + assert world_size >= ep and world_size % ep == 0 self.world_size = world_size - self.attn_config = attn_config - def need_dummy_batch(self): - """Need dummy batch.""" - if self.dp == 1: - return False - return self.tp > 1 or self.ep > 1 + # tp + self.attn_tp = self.world_size // dp + self.mlp_tp = tp + self.moe_tp = 1 if ep > 1 else tp + self.tp = self.attn_tp + + # tp mode + self.mlp_tp_mode = TPMode.DEFAULT if self.attn_tp == self.mlp_tp else TPMode.DP_TP + self.moe_tp_mode = TPMode.DEFAULT if self.attn_tp == self.moe_tp else TPMode.DP_TP + 
+ def get_tp_by_layer(self, layer_type: str): + """Get tp by layer type.""" + if layer_type == 'attn': + return self.attn_tp, TPMode.DEFAULT + elif layer_type == 'mlp': + return self.mlp_tp, self.mlp_tp_mode + elif layer_type == 'moe': + return self.moe_tp, self.moe_tp_mode + elif layer_type is None: + # for some layer that we don't need tp + return 1, TPMode.DEFAULT + else: + raise ValueError(f'Unknown layer type: {layer_type}') def _override_hf_config_dict(hf_config: dict, key: str, hf_overrides): @@ -261,10 +296,7 @@ def from_hf_config(cls, from lmdeploy.pytorch.configurations import AutoModelConfigBuilder if dist_config is None: dist_config = DistConfig() - if dist_config.dp == 1: - tp = dist_config.tp - else: - tp = 1 + tp = dist_config.attn_tp model_config = AutoModelConfigBuilder.build(hf_config, model_path, tp=tp) diff --git a/lmdeploy/pytorch/distributed.py b/lmdeploy/pytorch/distributed.py index 46b89a1b04..b17a44a802 100644 --- a/lmdeploy/pytorch/distributed.py +++ b/lmdeploy/pytorch/distributed.py @@ -2,125 +2,245 @@ import threading from contextlib import contextmanager from dataclasses import dataclass -from typing import List +from datetime import timedelta +from typing import List, Optional from torch import distributed as dist -from torch.distributed import ReduceOp +from torch.distributed import ProcessGroup, ReduceOp # noqa: F401 from .config import DistConfig +@dataclass +class DistGroup: + """Distributed group.""" + rank: int = 0 + cpu_group: dist.ProcessGroup = None + gpu_group: dist.ProcessGroup = None + cpu_groups: List[dist.ProcessGroup] = None + gpu_groups: List[dist.ProcessGroup] = None + + def close(self): + """Close groups.""" + if not dist.is_initialized(): + return + if self.cpu_groups is not None: + for group in self.cpu_groups: + dist.destroy_process_group(group) + self.cpu_groups = None + if self.gpu_groups is not None: + for group in self.gpu_groups: + dist.destroy_process_group(group) + self.gpu_groups = None + + +def _build_tp_group_impl(tp: int, + rank: int, + world_size: int, + timeout: timedelta, + cpu_backend: str = 'gloo', + ccl_backend: str = 'nccl'): + """Build tp group.""" + assert tp > 1 + tp_rank = rank % tp + tp_group_id = rank // tp + ranks = range(world_size) + tp_gpu_groups = [] + tp_cpu_groups = [] + for start in range(0, world_size, tp): + tp_ranks = ranks[start:start + tp] + group = dist.new_group(ranks=tp_ranks, timeout=timeout, backend=ccl_backend) + tp_gpu_groups.append(group) + cpu_group = dist.new_group(ranks=tp_ranks, timeout=timeout, backend=cpu_backend) + tp_cpu_groups.append(cpu_group) + tp_gpu_group = tp_gpu_groups[tp_group_id] + tp_cpu_group = tp_cpu_groups[tp_group_id] + + return DistGroup( + rank=tp_rank, + cpu_group=tp_cpu_group, + gpu_group=tp_gpu_group, + cpu_groups=tp_cpu_groups, + gpu_groups=tp_gpu_groups, + ) + + +def _build_attn_tp_group(context: 'DistContext', + timeout: timedelta, + cpu_backend: str = 'gloo', + ccl_backend: str = 'nccl'): + """Build attention tp group.""" + dist_config = context.dist_config + tp = dist_config.attn_tp + # skip if tp == 1 + if tp == 1: + context.attn_tp_group = DistGroup(rank=0) + return + + dist_group = _build_tp_group_impl( + tp, + context.rank, + dist_config.world_size, + timeout=timeout, + cpu_backend=cpu_backend, + ccl_backend=ccl_backend, + ) + context.attn_tp_group = dist_group + + +def _build_mlp_tp_group(context: 'DistContext', + timeout: timedelta, + cpu_backend: str = 'gloo', + ccl_backend: str = 'nccl'): + """Build mlp tp group.""" + dist_config = context.dist_config + tp 
= dist_config.mlp_tp + # skip if tp == 1 + if tp == 1: + context.mlp_tp_group = DistGroup(rank=0) + return + + # reuse attn tp group + if tp == dist_config.attn_tp: + context.mlp_tp_group = context.attn_tp_group + return + + dist_group = _build_tp_group_impl( + tp, + context.rank, + dist_config.world_size, + timeout=timeout, + cpu_backend=cpu_backend, + ccl_backend=ccl_backend, + ) + context.mlp_tp_group = dist_group + + +def _build_moe_tp_group(context: 'DistContext', + timeout: timedelta, + cpu_backend: str = 'gloo', + ccl_backend: str = 'nccl'): + """Build moe tp group.""" + dist_config = context.dist_config + tp = dist_config.moe_tp + # skip if tp == 1 + if tp == 1: + context.moe_tp_group = DistGroup(rank=0) + return + + # reuse attn tp group + if tp == dist_config.attn_tp: + context.moe_tp_group = context.attn_tp_group + return + + # reuse mlp tp group + if tp == dist_config.mlp_tp: + context.moe_tp_group = context.mlp_tp_group + return + + dist_group = _build_tp_group_impl( + tp, + context.rank, + dist_config.world_size, + timeout=timeout, + cpu_backend=cpu_backend, + ccl_backend=ccl_backend, + ) + context.moe_tp_group = dist_group + + +def _build_tp_group(context: 'DistContext', timeout: timedelta, cpu_backend: str = 'gloo', ccl_backend: str = 'nccl'): + """Build tp group.""" + _build_attn_tp_group(context, timeout, cpu_backend, ccl_backend) + _build_mlp_tp_group(context, timeout, cpu_backend, ccl_backend) + _build_moe_tp_group(context, timeout, cpu_backend, ccl_backend) + context.tp_group = context.attn_tp_group + + @dataclass class DistContext: rank: int = 0 - world_size: int = 1 - tp: int = 1 - dp: int = 1 - ep: int = 1 - tp_rank: int = 0 dp_rank: int = 0 ep_rank: int = 0 - world_cpu_group: dist.ProcessGroup = None - tp_cpu_group: dist.ProcessGroup = None - tp_gpu_group: dist.ProcessGroup = None - tp_gpu_groups: List[dist.ProcessGroup] = None - dp_cpu_group: dist.ProcessGroup = None - dp_gpu_group: dist.ProcessGroup = None + + tp_group: DistGroup = None + attn_tp_group: DistGroup = None + mlp_tp_group: DistGroup = None + moe_tp_group: DistGroup = None + ep_gpu_group: dist.ProcessGroup = None ep_gpu_groups: List[dist.ProcessGroup] = None dist_config: DistConfig = None + @classmethod + def _build_ep_group(cls, context: 'DistContext', timeout: timedelta, ccl_backend: str = 'nccl'): + """Build ep group.""" + dist_config = context.dist_config + ep = dist_config.ep + if ep <= 1: + return + + dp_rank = context.dp_rank + world_size = dist_config.world_size + ep_rank = context.rank % ep + ep_group_id = dp_rank // ep + ranks = range(world_size) + ep_gpu_groups = [] + for start in range(0, world_size, ep): + ep_ranks = ranks[start:start + ep] + group = dist.new_group(ranks=ep_ranks, timeout=timeout, backend=ccl_backend) + ep_gpu_groups.append(group) + ep_gpu_group = ep_gpu_groups[ep_group_id] + + context.ep_rank = ep_rank + context.ep_gpu_group = ep_gpu_group + context.ep_gpu_groups = ep_gpu_groups + @classmethod def build(cls, rank: int = 0, dist_config: DistConfig = None, ccl_backend: str = 'nccl'): """Build dist context.""" - from datetime import timedelta timeout = timedelta(days=35600) cpu_backend = 'gloo' if dist_config is None: dist_config = DistConfig() - tp = dist_config.tp - dp = dist_config.dp - ep = dist_config.ep - world_size = dist_config.world_size - dp_rank = dist_config.dp_rank + dp_rank = dist_config.dp_rank + world_size = dist_config.world_size + context = DistContext(rank=rank, + dp_rank=dp_rank, + dist_config=dist_config, + attn_tp_group=DistGroup(rank=0), + 
mlp_tp_group=DistGroup(rank=0), + moe_tp_group=DistGroup(rank=0), + tp_group=DistGroup(rank=0)) if world_size == 1: - return DistContext(dist_config=dist_config) + return context assert dist.is_initialized() - # world(assume world group is gloo) - world_cpu_group = dist.GroupMember.WORLD - - tp_rank = rank % tp # tp - tp_gpu_group = None - tp_gpu_groups = None - tp_cpu_group = None - tp_group_id = dp_rank // tp - if tp > 1: - # all tp groups should be created in all procs - ranks = range(world_size) - tp_gpu_groups = [] - tp_cpu_groups = [] - for start in range(0, world_size, tp): - tp_ranks = ranks[start:start + tp] - group = dist.new_group(ranks=tp_ranks, timeout=timeout, backend=ccl_backend) - tp_gpu_groups.append(group) - cpu_group = dist.new_group(ranks=tp_ranks, timeout=timeout, backend=cpu_backend) - tp_cpu_groups.append(cpu_group) - tp_gpu_group = tp_gpu_groups[tp_group_id] - tp_cpu_group = tp_cpu_groups[tp_group_id] - - ep_rank = rank % ep - ep_gpu_group = None - ep_gpu_groups = None - ep_group_id = dp_rank // ep - if ep > 1: - ranks = range(world_size) - ep_gpu_groups = [] - for start in range(0, world_size, ep): - ep_ranks = ranks[start:start + ep] - group = dist.new_group(ranks=ep_ranks, timeout=timeout, backend=ccl_backend) - ep_gpu_groups.append(group) - ep_gpu_group = ep_gpu_groups[ep_group_id] - - dp_cpu_group = None - if dp > 1: - dp_cpu_group = dist.new_group(ranks=range(dp), timeout=timeout, backend=cpu_backend) - - context = DistContext( - rank=rank, - world_size=world_size, - tp=tp, - dp=dp, - ep=ep, - tp_rank=tp_rank, - dp_rank=dp_rank, - ep_rank=ep_rank, - world_cpu_group=world_cpu_group, - tp_cpu_group=tp_cpu_group, - tp_gpu_group=tp_gpu_group, - tp_gpu_groups=tp_gpu_groups, - dp_cpu_group=dp_cpu_group, - dp_gpu_group=None, - ep_gpu_group=ep_gpu_group, - ep_gpu_groups=ep_gpu_groups, - dist_config=dist_config, - ) + _build_tp_group(context, timeout, cpu_backend=cpu_backend, ccl_backend=ccl_backend) + + # ep + cls._build_ep_group(context, timeout, ccl_backend=ccl_backend) + return context def close(self): """Close groups.""" if not dist.is_initialized(): return - if self.tp_gpu_groups is not None: - for group in self.tp_gpu_groups: - dist.destroy_process_group(group) + if self.attn_tp_group is not None: + self.attn_tp_group.close() + if self.mlp_tp_group is not None: + self.mlp_tp_group.close() + if self.moe_tp_group is not None: + self.moe_tp_group.close() if self.ep_gpu_groups is not None: for group in self.ep_gpu_groups: dist.destroy_process_group(group) + self.ep_gpu_groups = None DefaultContext = DistContext.build() @@ -141,6 +261,10 @@ def set_context(self, context: DistContext): """Set current context.""" self.t_local.device_context = context + def current_config(self) -> DistConfig: + """Get current dist config.""" + return self.current_context().dist_config + @contextmanager def context(self, context: DistContext): """Context manager.""" @@ -164,25 +288,34 @@ def get_dist_manager(): def get_world_rank(): """Get distributed world size and rank.""" ctx = get_dist_manager().current_context() - world_size = ctx.world_size + world_size = ctx.dist_config.world_size rank = ctx.rank return world_size, rank -def get_tp_world_rank(): +def get_tp_world_rank(layer_type: Optional[str] = None): ctx = get_dist_manager().current_context() - return ctx.tp, ctx.tp_rank + if layer_type is None: + return ctx.dist_config.tp, ctx.tp_group.rank + elif layer_type == 'attn': + return ctx.dist_config.attn_tp, ctx.attn_tp_group.rank + elif layer_type == 'mlp': + return 
ctx.dist_config.mlp_tp, ctx.mlp_tp_group.rank + elif layer_type == 'moe': + return ctx.dist_config.moe_tp, ctx.moe_tp_group.rank + else: + raise RuntimeError(f'Unknown layer type: {layer_type}') def get_dp_world_rank(): ctx = get_dist_manager().current_context() - return ctx.dp, ctx.dp_rank + return ctx.dist_config.dp, ctx.dp_rank def get_ep_world_rank(): ctx = get_dist_manager().current_context() - return ctx.ep, ctx.ep_rank + return ctx.dist_config.ep, ctx.ep_rank def _check_group_device(device: str): @@ -193,48 +326,37 @@ def _check_group_device(device: str): def get_process_group(device: str = None): """Get process group.""" - ctx = get_dist_manager().current_context() - if device is None: - return dist.GroupMember.WORLD - - _check_group_device(device) - - if device == 'cpu': - return ctx.world_cpu_group - else: - raise RuntimeError('gpu world group is not supported.') + return dist.GroupMember.WORLD -def get_tp_group(device: str = 'gpu'): +def get_tp_group(device: str = 'gpu', layer_type: str = 'attn'): """Get tp group.""" ctx = get_dist_manager().current_context() _check_group_device(device) - if device == 'cpu': - return ctx.tp_cpu_group + if layer_type == 'attn': + tp_group = ctx.attn_tp_group + elif layer_type == 'mlp': + tp_group = ctx.mlp_tp_group + elif layer_type == 'moe': + tp_group = ctx.moe_tp_group else: - return ctx.tp_gpu_group - + raise RuntimeError(f'Unknown layer type: {layer_type}') -def get_dp_group(device: str = 'gpu'): - """Get dp group.""" - ctx = get_dist_manager().current_context() - - _check_group_device(device) + if tp_group is None: + return None if device == 'cpu': - return ctx.dp_cpu_group + return tp_group.cpu_group else: - return ctx.dp_gpu_group + return tp_group.gpu_group def get_group(group_type: str, device: str): """Get group.""" if group_type == 'tp': return get_tp_group(device) - elif group_type == 'dp': - return get_dp_group(device) elif group_type in ['world', 'all']: return get_process_group(device) else: diff --git a/lmdeploy/pytorch/engine/executor/base.py b/lmdeploy/pytorch/engine/executor/base.py index 9e50843a80..abc06094a3 100644 --- a/lmdeploy/pytorch/engine/executor/base.py +++ b/lmdeploy/pytorch/engine/executor/base.py @@ -37,7 +37,6 @@ def __init__(self, self.dist_config = dist_config self.misc_config = misc_config self.dp = dist_config.dp - self.tp = dist_config.tp self.world_size = dist_config.world_size self.device_type = device_type @@ -163,7 +162,7 @@ def update_configs(self): logger.debug(f'minimal free gpu memory: {free_mem >> 20} mb') vocal_size = self.model_config.vocab_size - tp = self.dist_config.attn_config.tp + tp = self.dist_config.attn_tp cache_block_size = CacheEngine.get_cache_block_size(cache_config.block_size, model_config, tp, cache_config.quant_policy) runtime_mem, max_prefill_token_num = self._get_runtime_size(free_mem, cache_block_size, vocal_size) diff --git a/lmdeploy/pytorch/engine/executor/ray_executor.py b/lmdeploy/pytorch/engine/executor/ray_executor.py index 327d56a5ca..c142a62610 100644 --- a/lmdeploy/pytorch/engine/executor/ray_executor.py +++ b/lmdeploy/pytorch/engine/executor/ray_executor.py @@ -199,8 +199,9 @@ def warmup_dist(self): from lmdeploy.pytorch.distributed import all_reduce, get_dist_manager with get_dist_manager().context(self.dist_ctx): + group = self.dist_ctx.tp_group.gpu_group tmp = torch.empty((1, ), device='cuda') - all_reduce(tmp) + all_reduce(tmp, group=group) def pack_output(self, output: Dict): """Pack output.""" @@ -242,14 +243,11 @@ def __init__(self, adapters=adapters, 
device_type=device_type) - self.dp_rank = dist_config.dp_rank device_ctx = DeviceContext(device_type) with get_device_manager().context(device_ctx): logger.info('Init ray cluster.') - ray_world_size = self.world_size - if self.dp > 1: - ray_world_size = 1 - self.ray_ctx = RayContext(ray_world_size, dp=dist_config.dp, device_type=device_type) + attn_tp = dist_config.attn_tp + self.ray_ctx = RayContext(attn_tp, dp=dist_config.dp, device_type=device_type) placement_group = self.ray_ctx.get_placement_group() self.placement_group = placement_group @@ -285,13 +283,11 @@ def __init__(self, self._init_distributed_environment_by_device(device_type) logger.info('Init distributed process group.') - if self.dp == 1: - ray.get([ - worker.init_process_group.remote(rank, self.master_addr, self.master_port) - for rank, worker in enumerate(self.workers) - ]) - else: - ray.get(self.workers[0].init_process_group.remote(self.dp_rank, self.master_addr, self.master_port)) + rank_offset = dist_config.dp_rank * attn_tp + ray.get([ + worker.init_process_group.remote(rank + rank_offset, self.master_addr, self.master_port) + for rank, worker in enumerate(self.workers) + ]) if self.dist_config.world_size > 1: logger.info('Warming up distribute environment, this might take long time, please waiting...') @@ -510,7 +506,8 @@ def _init_workers_ray(self, placement_group: PlacementGroup, worker_kwargs: dict for bundle_id, bundle in enumerate(placement_group.bundle_specs): if bundle.get(device_str, 0) and self._valid_bundle_id(bundle_id): bundle_indices.append(bundle_id) - bundle_indices = bundle_indices[:self.world_size] + attn_tp = self.dist_config.attn_tp + bundle_indices = bundle_indices[:attn_tp] workers = list() for _, bundle_id in enumerate(bundle_indices): diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 7b332b6f0c..1ac733d6d6 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -118,7 +118,7 @@ def __init__(self, dist_ctx: DistContext, stream: torch.Stream): from lmdeploy.pytorch import envs self.rank = dist_ctx.rank self.dp_rank = dist_ctx.dp_rank - self.dp = dist_ctx.dp + self.dp = dist_ctx.dist_config.dp self.stream = stream self.profiler = None if self.dp == 1: @@ -335,6 +335,7 @@ def __init__(self, device = 'cuda' self.backend_config = backend_config self.misc_config = misc_config + self.dist_config = dist_ctx.dist_config rank = dist_ctx.rank self.model_path = model_path @@ -342,19 +343,17 @@ def __init__(self, self.device = device self.rank = rank - tp_rank = dist_ctx.tp_rank - tp = dist_ctx.tp - world_size = dist_ctx.world_size + tp = self.dist_config.tp + world_size = self.dist_config.world_size self.tp = tp self.world_size = world_size - self.tp_rank = tp_rank self.patched_model = None self.cache_engine = None self.profiler: AgentProfiler = None # microbatch - self.enable_microbatch = self.dist_ctx.dist_config.enable_microbatch + self.enable_microbatch = self.dist_config.enable_microbatch self.enable_microbatch_prefill_batchsize_threshold = \ int(getenv('ENABLE_MICROBATCH_PREFILL_BATCHSIZE_THRESHOLD', 2)) self.enable_microbatch_prefill_token_threshold = \ @@ -395,12 +394,15 @@ def warmup(self): with self.all_context(): max_batches = self.cache_config.max_batches num_tokens = max_batches + dp = self.dist_config.dp # warmup prefill inputs = self.inputs_strategy.make_dummy(max_batches, is_decoding=False, device='cuda', vocab_size=self.model_config.vocab_size) + if dp > 1: + inputs.build_dp_meta() 
self._forward_impl(inputs) # warmup decoding(with cuda graph) @@ -411,6 +413,8 @@ def warmup(self): is_decoding=True, device='cuda', vocab_size=self.model_config.vocab_size) + if dp > 1: + inputs.build_dp_meta() self._forward_impl(inputs) def _slice_outs(self, inputs: torch.Tensor, seq_length: torch.LongTensor): @@ -472,8 +476,8 @@ def get_output(self): async def __long_context_single_forward(new_inputs, max_seqlen: int): """One large sequence.""" - dist_ctx = get_dist_manager().current_context() - dp = dist_ctx.dp + dist_config = get_dist_manager().current_config() + dp = dist_config.dp model_metas = new_inputs[0].model_metas output_gather = _OutputGather(max_seqlen) for inp in new_inputs: @@ -557,15 +561,15 @@ def _push_output(self, output: BatchedOutputs): self._out_que.put_nowait((output, event)) @contextmanager - def _broadcast_next_token(self, next_token_ids: torch.Tensor, dist_ctx: DistContext = None, enable: bool = True): + def _broadcast_next_token(self, next_token_ids: torch.Tensor, dist_ctx: DistContext, enable: bool = True): if not enable: yield return - if dist_ctx is None: - dist_ctx = get_dist_manager().current_context() - tp_gpu_group = dist_ctx.tp_gpu_group - handle = dist.broadcast(next_token_ids, src=0, group=tp_gpu_group, async_op=True) + dist_ctx = get_dist_manager().current_context() + tp_gpu_group = dist_ctx.attn_tp_group.gpu_group + rank = dist.get_global_rank(tp_gpu_group, 0) + handle = dist.broadcast(next_token_ids, src=rank, group=tp_gpu_group, async_op=True) yield handle.wait() @@ -604,6 +608,7 @@ async def __prepare_dp(): return nonlocal inputs, sync_long_context, is_all_dummy + world_size = self.dist_config.world_size # gather dp forward metadata batch_size = inputs.seq_length.numel() @@ -619,7 +624,7 @@ async def __prepare_dp(): self.enable_microbatch_prefill_batchsize_threshold and \ tokens_num >= self.enable_microbatch_prefill_token_threshold dp_forward_meta.append(int(enable_microbatch)) - gathered_meta = DistGatherScalar(dp_forward_meta, dp, device='cuda') + gathered_meta = DistGatherScalar(dp_forward_meta, world_size, device='cuda') yield @@ -627,7 +632,7 @@ async def __prepare_dp(): # check is_decoding all_is_decoding = gathered_meta[:, 0] - assert all_is_decoding.sum().item() in [0, dp] + assert all_is_decoding.sum().item() in [0, world_size] # check if all inputs are dummy inputs is_all_dummy = gathered_meta[:, 1].all() @@ -655,9 +660,10 @@ async def __prepare_dp(): # dist tools dist_ctx = get_dist_manager().current_context() + dist_config = dist_ctx.dist_config rank = dist_ctx.rank - tp = dist_ctx.tp - dp = dist_ctx.dp + tp = dist_config.attn_tp + dp = dist_config.dp sync_long_context = False if dp == 1 else sync_long_context is_decoding = inputs.is_decoding @@ -698,7 +704,7 @@ async def __prepare_dp(): if is_dummy: continue - need_broadcast_next = (dp == 1 and tp > 1 and idx < loop_count - 1) + need_broadcast_next = (tp > 1 and idx < loop_count - 1) # sampling and stopping if need_output: @@ -754,8 +760,8 @@ async def __prepare_dp(): async def _async_loop_background(self, forward_event: asyncio.Event = None): """Async loop background.""" with self.all_context(), torch.cuda.stream(self.stream), torch.inference_mode(): - dist_ctx = get_dist_manager().current_context() - dp = dist_ctx.dp + dist_config = get_dist_manager().current_config() + dp = dist_config.dp # for dp if dp > 1: @@ -841,7 +847,7 @@ def start(self, forward_event: asyncio.Event = None): def stop(self): """Stop task.""" - if self.dist_ctx.dp > 1: + if self.dist_config.dp > 1: return if 
self.profiler is not None: @@ -857,7 +863,7 @@ def stop(self): async def stop_async(self): """Stop task.""" - if self.dist_ctx.dp > 1: + if self.dist_config.dp > 1: return if self.profiler is not None: @@ -947,14 +953,14 @@ def build_graph_runner(self): def build_cache_engine(self): """Build cache engine.""" with self.all_context(): - dist_ctx = self.dist_ctx - attn_dist_cfg = dist_ctx.dist_config.attn_config - tp = attn_dist_cfg.tp + dist_ctx = get_dist_manager().current_context() + dist_cfg = self.dist_config + tp = dist_cfg.attn_tp self.cache_engine = CacheEngine(self.cache_config, self.model_config, rank=self.rank, - tp_rank=self.tp_rank, + tp_rank=dist_ctx.attn_tp_group.rank, world_size=tp, cache_stream=self.cache_stream) @@ -1008,7 +1014,7 @@ def _construct(item): with self.all_context(): serialized_data = request.serialized_named_tensors if isinstance(serialized_data, list): - serialized_data = serialized_data[self.dist_ctx.tp_rank] + serialized_data = serialized_data[self.dist_ctx.tp_group.rank] weights = ForkingPickler.loads(base64.b64decode(serialized_data)) weights = [(k, _construct(v)) for k, v in weights] self.patched_model.get_model().load_weights(weights) diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index a377c9d4d6..c8d8aa0c1c 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -19,22 +19,39 @@ @dataclass class DPMeta: tp_sizes: List[int] = None - ep_sizes: List[int] = None + moe_tp_sizes: List[int] = None + + @staticmethod + def _gather_tp_sizes(tp: int, attn_tp: int, seqlen: int, layer_type: str): + """Gather tp size.""" + if tp > 1 and tp != attn_tp: + tp_sizes = [None for _ in range(tp)] + tp_group = dist.get_tp_group('gpu', layer_type=layer_type) + dist.all_gather_object(tp_sizes, seqlen, group=tp_group) + else: + tp_sizes = [seqlen] + return tp_sizes @classmethod def build(cls, seqlen: int): """Get dp meta.""" - dist_ctx = dist.get_dist_manager().current_context() + dist_config = dist.get_dist_manager().current_config() + attn_tp = dist_config.attn_tp - tp = dist_ctx.tp - if tp > 1: - tp_sizes = [None for _ in range(tp)] - tp_group = dist.get_tp_group('gpu') - dist.all_gather_object(tp_sizes, seqlen, group=tp_group) + mlp_tp = dist_config.mlp_tp + tp_sizes = cls._gather_tp_sizes(mlp_tp, attn_tp, seqlen, layer_type='mlp') + + moe_tp = dist_config.moe_tp + if moe_tp == mlp_tp: + moe_tp_sizes = tp_sizes else: - tp_sizes = [seqlen] + moe_tp_sizes = cls._gather_tp_sizes(moe_tp, attn_tp, seqlen, layer_type='moe') + + return DPMeta(tp_sizes=tp_sizes, moe_tp_sizes=moe_tp_sizes) - return DPMeta(tp_sizes=tp_sizes, ) + def sync_tp_size(self, tp_size: int): + self.tp_sizes = [tp_size] * len(self.tp_sizes) + self.moe_tp_sizes = [tp_size] * len(self.moe_tp_sizes) @dataclass diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py index a10e5da520..c4bc09e876 100644 --- a/lmdeploy/pytorch/models/deepseek_v2.py +++ b/lmdeploy/pytorch/models/deepseek_v2.py @@ -341,16 +341,9 @@ def __init__(self, batch: int, in_features: int, out_features: int, dtype: torch self.dtype = dtype self.device = device - def _get_tp_world_rank(self): - """Get tp world rank.""" - dist_ctx = get_dist_manager().current_context() - if dist_ctx.dp == 1: - return get_tp_world_rank() - return 1, 0 - def _update_batch(self, batch: int): """Update out features.""" - world_size, _ = self._get_tp_world_rank() + world_size, _ = get_tp_world_rank('attn') batch = batch // world_size return batch @@ -360,7 +353,7 @@ 
def create_weight(self, batch: int, in_features: int, out_features: int, dtype: def weight_loader(self, param: nn.Parameter, weight: torch.Tensor): """Weight loader.""" - world_size, rank = self._get_tp_world_rank() + world_size, rank = get_tp_world_rank('attn') weight = weight.chunk(world_size, 0)[rank] param.data.copy_(weight) @@ -521,11 +514,11 @@ def forward( attn_metadata: Any = None, ): """Rewrite of LlamaAttention.forward.""" - dist_ctx = get_dist_manager().current_context() - if dist_ctx.dp > 1: + dist_config = get_dist_manager().current_config() + if dist_config.dp > 1: num_heads = self.num_heads else: - world_size = dist_ctx.world_size + world_size = dist_config.world_size num_heads = self.num_heads // world_size nope_size = self.kv_lora_rank q_len = hidden_states.size(1) @@ -678,9 +671,10 @@ def __init__(self, config: Any, layer_idx, dtype: torch.dtype = None, device: to self.topk_group = config.topk_group dist_ctx = get_dist_manager().current_context() - dp = dist_ctx.dp - world_size = dist_ctx.world_size - moe_all_reduce = dp > 1 and dist_ctx.tp > 1 + dist_config = dist_ctx.dist_config + dp = dist_config.dp + world_size = dist_config.world_size + moe_all_reduce = dp > 1 and dist_config.tp > 1 if get_dist_manager().current_context().dist_config.enable_eplb: eplb_dispatch_info = EPLBManager.get_dispatch_info( ep_rank=dist_ctx.ep_rank, @@ -753,8 +747,8 @@ def __init__(self, super().__init__() quantization_config = getattr(config, 'quantization_config', None) if is_shared_expert: - dist_ctx = get_dist_manager().current_context() - dp = dist_ctx.dp + dist_config = get_dist_manager().current_config() + dp = dist_config.dp if dp == 1: # split weight, do all reduce in moe is_tp = True diff --git a/lmdeploy/pytorch/models/llama4.py b/lmdeploy/pytorch/models/llama4.py index 4a84f35d85..28389f970d 100644 --- a/lmdeploy/pytorch/models/llama4.py +++ b/lmdeploy/pytorch/models/llama4.py @@ -205,8 +205,8 @@ def __init__(self, config: Llama4TextConfig, dtype: torch.dtype = None, device: ) self.shared_expert = Llama4TextMLP(config, dtype=dtype, device=device, is_tp=True, all_reduce=False) - dist_ctx = dist.get_dist_manager().current_context() - self.tp = dist_ctx.tp + dist_config = dist.get_dist_manager().current_config() + self.tp = dist_config.tp def forward(self, hidden_states: torch.Tensor): """forward.""" diff --git a/lmdeploy/pytorch/models/qwen3_moe.py b/lmdeploy/pytorch/models/qwen3_moe.py index 464953f264..7818787f3e 100644 --- a/lmdeploy/pytorch/models/qwen3_moe.py +++ b/lmdeploy/pytorch/models/qwen3_moe.py @@ -6,7 +6,7 @@ from torch import nn from transformers.configuration_utils import PretrainedConfig -from lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank, get_tp_world_rank +from lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config from lmdeploy.pytorch.nn.eplb import EPLBManager @@ -198,8 +198,6 @@ def __init__(self, self.softmax_topk = SoftmaxTopK(self.top_k) - world_size, _ = get_tp_world_rank() - _all_reduce = world_size > 1 if get_dist_manager().current_context().dist_config.enable_eplb: dist_ctx = get_dist_manager().current_context() self.eplb_dispatch_info = EPLBManager.get_dispatch_info( @@ -216,7 +214,7 @@ def __init__(self, dtype=dtype, device=device, quant_config=quantization_config, - all_reduce=_all_reduce, + all_reduce=True, 
layer_idx=layer_idx, ) diff --git a/lmdeploy/pytorch/models/sdar_moe.py b/lmdeploy/pytorch/models/sdar_moe.py index 522d2aed95..190411b898 100644 --- a/lmdeploy/pytorch/models/sdar_moe.py +++ b/lmdeploy/pytorch/models/sdar_moe.py @@ -6,7 +6,6 @@ from torch import nn from transformers.configuration_utils import PretrainedConfig -from lmdeploy.pytorch.distributed import get_tp_world_rank from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj, @@ -180,8 +179,6 @@ def __init__(self, self.softmax_topk = SoftmaxTopK(self.top_k) - world_size, _ = get_tp_world_rank() - _all_reduce = world_size > 1 self.experts = build_fused_moe( self.hidden_dim, self.ffn_dim, @@ -191,7 +188,7 @@ def __init__(self, dtype=dtype, device=device, quant_config=quantization_config, - all_reduce=_all_reduce, + all_reduce=True, layer_idx=layer_idx, ) diff --git a/lmdeploy/pytorch/nn/attention.py b/lmdeploy/pytorch/nn/attention.py index 7a1654db4b..52768591ff 100644 --- a/lmdeploy/pytorch/nn/attention.py +++ b/lmdeploy/pytorch/nn/attention.py @@ -2,7 +2,7 @@ import torch from torch import nn -from lmdeploy.pytorch.distributed import get_dist_manager, get_tp_world_rank +from lmdeploy.pytorch.distributed import get_tp_world_rank from ..backends import OpType, get_backend from ..backends.attention import AttentionMetadata @@ -11,11 +11,7 @@ def _update_num_heads(num_heads: int, num_kv_heads: int): """Update heads.""" - dist_ctx = get_dist_manager().current_context() - if dist_ctx.dp == 1: - world_size, rank = get_tp_world_rank() - else: - world_size, rank = 1, 0 + world_size, rank = get_tp_world_rank('attn') num_heads = get_distribute_size(num_heads, world_size, rank) num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank) return num_heads, num_kv_heads diff --git a/lmdeploy/pytorch/nn/linear/__init__.py b/lmdeploy/pytorch/nn/linear/__init__.py index e57eaed9e5..3738e921c4 100644 --- a/lmdeploy/pytorch/nn/linear/__init__.py +++ b/lmdeploy/pytorch/nn/linear/__init__.py @@ -4,7 +4,8 @@ import torch from torch import nn -from lmdeploy.pytorch.distributed import get_dist_manager, get_dp_world_rank, get_tp_world_rank +from lmdeploy.pytorch.config import TPMode +from lmdeploy.pytorch.distributed import get_dist_manager, get_tp_world_rank from .awq import AwqLinear, MergedAwqLinear, QKVAwqLinear from .blocked_fp8 import BlockedF8Linear, MergedBlockedF8Linear, QKVBlockedF8Linear @@ -13,51 +14,25 @@ from .w8a8 import MergedW8A8Linear, QKVW8A8Linear, W8A8Linear -def _is_dp_enabled(): - """Is dp.""" - return get_dp_world_rank()[0] > 1 - - -def _get_dp_gather(is_tp: bool): - """Get dp gather.""" - dp_gather = True - if not _is_dp_enabled(): - # disable if not dp - dp_gather = False - if not is_tp: - dp_gather = False - return dp_gather - - -def _get_dp_tp_meta(all_reduce: bool = True): - """Get tp meta.""" - dist_ctx = get_dist_manager().current_context() - dist_attn_cfg = dist_ctx.dist_config.attn_config - tp = dist_attn_cfg.tp - is_tp = tp > 1 - all_reduce = all_reduce if is_tp else False - return is_tp, all_reduce - - -def build_linear(in_features: int, - out_features: int, - bias: bool, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - colwise: bool = True, - is_tp: bool = False, - quant_config: Any = None, - all_reduce: bool = True, - tp_align_size: int 
= 1, - dp_gather: bool = False, - dp_scatter: bool = False) -> nn.Module: +def build_linear( + in_features: int, + out_features: int, + bias: bool, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + colwise: bool = True, + is_tp: bool = False, + quant_config: Any = None, + all_reduce: bool = True, + tp_align_size: int = 1, + dp_gather: bool = False, + layer_type: str = 'attn', +) -> nn.Module: """Build linear.""" - if is_tp: - is_tp = get_tp_world_rank()[0] > 1 - if not is_tp: - all_reduce = False - - if (dp_scatter or dp_gather) and quant_config is not None: + if layer_type is None: + layer_type = 'attn' + all_reduce = all_reduce if is_tp else False + if dp_gather and quant_config is not None: quant_method = quant_config['quant_method'] assert quant_method in ['fp8'], (f'Do not support dp_gather with quant_method={quant_method}') @@ -73,7 +48,7 @@ def build_linear(in_features: int, all_reduce=all_reduce, tp_align_size=tp_align_size, dp_gather=dp_gather, - dp_scatter=dp_scatter, + layer_type=layer_type, ) quant_method = quant_config['quant_method'] @@ -94,6 +69,7 @@ def build_linear(in_features: int, colwise=colwise, is_tp=is_tp, all_reduce=all_reduce, + layer_type=layer_type, ) if quant_method == 'smooth_quant': return W8A8Linear(in_features, @@ -104,7 +80,8 @@ def build_linear(in_features: int, colwise=colwise, is_tp=is_tp, all_reduce=all_reduce, - quant_dtype=quant_dtype) + quant_dtype=quant_dtype, + layer_type=layer_type) elif quant_method == 'fp8': fmt = quant_config.get('fmt', 'e4m3') if fmt == 'e4m3': @@ -124,7 +101,7 @@ def build_linear(in_features: int, is_tp=is_tp, all_reduce=all_reduce, dp_gather=dp_gather, - dp_scatter=dp_scatter, + layer_type=layer_type, ) else: raise RuntimeError(f'Unsupported quant method: {quant_method}') @@ -139,16 +116,20 @@ def build_colwise_linear(in_features: int, tp_align_size: int = 1, quant_config: Any = None, dp_disable_tp: bool = False, - dp_gather: bool = False) -> nn.Module: + dp_gather: bool = False, + check_dist: bool = True, + layer_type: str = 'attn') -> nn.Module: """Build columnwise parallel linear layer.""" - if dp_disable_tp and is_tp: - is_tp, _ = _get_dp_tp_meta() - elif is_tp: - is_tp = get_tp_world_rank()[0] > 1 + if check_dist: + dist_config = get_dist_manager().current_config() + tp, tp_mode = dist_config.get_tp_by_layer(layer_type) + + # check is_tp + is_tp = is_tp if tp > 1 else False + is_tp = False if (dp_disable_tp and dist_config.dp > 1) else is_tp - if dp_gather: - assert not dp_disable_tp - dp_gather = _get_dp_gather(is_tp) + # check dp_gather + dp_gather = dp_gather if is_tp and tp_mode == TPMode.DP_TP else False return build_linear(in_features=in_features, out_features=out_features, @@ -160,7 +141,8 @@ def build_colwise_linear(in_features: int, quant_config=quant_config, all_reduce=False, tp_align_size=tp_align_size, - dp_gather=dp_gather) + dp_gather=dp_gather, + layer_type=layer_type) def build_rowwise_linear(in_features: int, @@ -173,10 +155,14 @@ def build_rowwise_linear(in_features: int, quant_config: Any = None, all_reduce: bool = True, dp_disable_tp: bool = False, - dp_scatter: bool = False) -> nn.Module: + check_dist: bool = True, + layer_type: str = 'attn') -> nn.Module: """Build rowwise parallel linear layer.""" - if dp_disable_tp and is_tp: - is_tp, all_reduce = _get_dp_tp_meta(all_reduce) + if check_dist: + dist_config = get_dist_manager().current_config() + tp, _ = dist_config.get_tp_by_layer(layer_type) + is_tp = is_tp if tp > 1 else False + is_tp = False if (dp_disable_tp and 
dist_config.dp > 1) else is_tp return build_linear( in_features=in_features, out_features=out_features, @@ -188,7 +174,7 @@ def build_rowwise_linear(in_features: int, quant_config=quant_config, all_reduce=all_reduce, tp_align_size=tp_align_size, - dp_scatter=dp_scatter, + layer_type=layer_type, ) @@ -202,10 +188,12 @@ def build_merged_colwise_linear( is_tp: bool = True, out_names: List[Any] = None, dp_gather: bool = False, + check_dist: bool = True, + layer_type: str = 'attn', ): """Merge linear.""" - if is_tp: - is_tp = get_tp_world_rank()[0] > 1 + if check_dist and is_tp: + is_tp = get_tp_world_rank(layer_type)[0] > 1 if dp_gather and quant_config is not None: quant_method = quant_config['quant_method'] @@ -219,7 +207,8 @@ def build_merged_colwise_linear( device=device, is_tp=is_tp, out_names=out_names, - dp_gather=dp_gather) + dp_gather=dp_gather, + layer_type=layer_type) quant_method = quant_config['quant_method'] quant_dtype = torch.int8 @@ -237,6 +226,7 @@ def build_merged_colwise_linear( bias=bias, device=device, is_tp=is_tp, + layer_type=layer_type, ) if quant_method == 'smooth_quant': return MergedW8A8Linear(in_features=in_features, @@ -246,7 +236,8 @@ def build_merged_colwise_linear( device=device, is_tp=is_tp, out_names=out_names, - quant_dtype=quant_dtype) + quant_dtype=quant_dtype, + layer_type=layer_type) elif quant_method == 'fp8': fmt = quant_config.get('fmt', 'e4m3') if fmt == 'e4m3': @@ -265,6 +256,7 @@ def build_merged_colwise_linear( is_tp=is_tp, out_names=out_names, dp_gather=dp_gather, + layer_type=layer_type, ) else: raise RuntimeError(f'Unsupported quant method: {quant_method}') @@ -280,19 +272,10 @@ def build_qkv_proj(in_features: int, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, is_tp: bool = True, - num_replicate_kv_heads: int = 1, - dp_disable_tp: bool = True, - all_reduce: bool = False, - dp_gather: bool = False): + num_replicate_kv_heads: int = 1): """Build qkv proj.""" - if dp_disable_tp and is_tp: - is_tp, _ = _get_dp_tp_meta(all_reduce) - elif is_tp: - is_tp = get_tp_world_rank()[0] > 1 - - if dp_gather: - assert not dp_disable_tp - dp_gather = _get_dp_gather(is_tp) + dist_config = get_dist_manager().current_config() + is_tp = is_tp if dist_config.attn_tp > 1 else False if head_size_v is None: head_size_v = head_size @@ -358,7 +341,7 @@ def build_qkv_proj(in_features: int, dtype=dtype, device=device, is_tp=is_tp, - dp_gather=dp_gather, + dp_gather=False, num_replicate_kv_heads=num_replicate_kv_heads) else: raise RuntimeError(f'Unsupported quant method: {quant_method}') @@ -374,8 +357,8 @@ def build_o_proj(in_features: int, quant_config: Any = None, all_reduce: bool = True) -> nn.Module: """Build down linear.""" - if is_tp: - is_tp, all_reduce = _get_dp_tp_meta(all_reduce) + dist_config = get_dist_manager().current_config() + is_tp = is_tp if dist_config.attn_tp > 1 else False return build_rowwise_linear( in_features=in_features, @@ -387,6 +370,8 @@ def build_o_proj(in_features: int, tp_align_size=tp_align_size, quant_config=quant_config, all_reduce=all_reduce, + check_dist=False, + layer_type='attn', ) @@ -400,12 +385,10 @@ def build_gateup_linear(in_features: int, out_names: List[Any] = None, dp_gather: bool = True): """Build gate up linear.""" - if dp_gather: - if is_tp: - is_tp = get_tp_world_rank()[0] > 1 - dp_gather = _get_dp_gather(is_tp) - elif is_tp: - is_tp, _ = _get_dp_tp_meta() + dist_config = get_dist_manager().current_config() + tp, tp_mode = dist_config.get_tp_by_layer('mlp') + is_tp = is_tp if tp > 1 else False 
+ dp_gather = dp_gather if is_tp and tp_mode == TPMode.DP_TP else False return build_merged_colwise_linear( in_features=in_features, @@ -417,6 +400,8 @@ def build_gateup_linear(in_features: int, is_tp=is_tp, out_names=out_names, dp_gather=dp_gather, + check_dist=False, + layer_type='mlp', ) @@ -428,19 +413,10 @@ def build_down_linear(in_features: int, is_tp: bool = False, tp_align_size: int = 1, quant_config: Any = None, - all_reduce: bool = True, - dp_scatter: bool = True) -> nn.Module: + all_reduce: bool = True) -> nn.Module: """Build down linear.""" - if dp_scatter: - if is_tp: - is_tp = get_tp_world_rank()[0] > 1 - if not _is_dp_enabled(): - # disable if not dp - dp_scatter = False - if not is_tp: - dp_scatter = False - elif is_tp: - is_tp, all_reduce = _get_dp_tp_meta(all_reduce) + dist_config = get_dist_manager().current_config() + is_tp = is_tp if dist_config.mlp_tp > 1 else False return build_rowwise_linear( in_features=in_features, @@ -452,5 +428,6 @@ def build_down_linear(in_features: int, tp_align_size=tp_align_size, quant_config=quant_config, all_reduce=all_reduce, - dp_scatter=dp_scatter, + check_dist=False, + layer_type='mlp', ) diff --git a/lmdeploy/pytorch/nn/linear/awq.py b/lmdeploy/pytorch/nn/linear/awq.py index ede2cea60c..5e24d93db7 100644 --- a/lmdeploy/pytorch/nn/linear/awq.py +++ b/lmdeploy/pytorch/nn/linear/awq.py @@ -8,7 +8,7 @@ from ..utils import chunk_aligned, get_distribute_size from .base import LinearBase -from .utils import QKVMixin, _get_tp_world_rank, check_qkv_split_layout +from .utils import QKVMixin, check_qkv_split_layout class AwqLinear(LinearBase): @@ -25,8 +25,14 @@ def __init__( colwise: bool = True, is_tp: bool = False, all_reduce: bool = True, + layer_type: str = 'attn', ): - super().__init__(dtype=torch.float16, device=device, colwise=colwise, is_tp=is_tp, all_reduce=all_reduce) + super().__init__(dtype=torch.float16, + device=device, + colwise=colwise, + is_tp=is_tp, + all_reduce=all_reduce, + layer_type=layer_type) if self.is_tp: in_features, out_features = self._get_io_features(in_features, out_features, w_bit, group_size, colwise) qweight, scales, qzeros, bias = self.create_weights(in_features, out_features, w_bit, group_size, bias, @@ -78,7 +84,7 @@ def register_all_parameters(self, def _get_io_features(self, in_features: int, out_features: int, w_bit: int, group_size: int, colwise: bool): """Get io features.""" align = max(32 // w_bit, group_size) - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() if colwise: out_features = get_distribute_size(out_features, world_size, rank, align=align) else: @@ -128,7 +134,7 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor): if not self.is_tp: return default_weight_loader(param, loaded_weight) - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() if self.colwise: return self._weight_loader_tp_colwise(param, loaded_weight, rank, world_size) else: @@ -159,7 +165,7 @@ def update_weights(self): def _forward_default(self, x, all_reduce, tp_sizes): """Default forward implement.""" - return self.impl.forward(x, self.qweight, self.scales, self.qzeros, self.bias, all_reduce) + return self.impl.forward(x, self.qweight, self.scales, self.qzeros, self.bias, all_reduce, group=self.tp_group) class MergedAwqLinear(AwqLinear): @@ -173,10 +179,11 @@ def __init__(self, bias: bool, device: Optional[torch.device] = None, is_tp: bool = True, - out_names: Optional[List[int]] = None): + out_names: 
Optional[List[int]] = None, + layer_type: str = 'attn'): + self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type=layer_type) self.split_section_s = all_out_features - self.is_tp = is_tp elem_per_int = 32 // w_bit self.split_section_wz = [size // elem_per_int for size in all_out_features] @@ -187,7 +194,15 @@ def __init__(self, assert len(out_names) == len(self.all_out_features) self.out_names_map = dict((name, idx) for idx, name in enumerate(out_names)) out_features = sum(all_out_features) - super().__init__(in_features, out_features, w_bit, group_size, bias, device, colwise=True, is_tp=is_tp) + super().__init__(in_features, + out_features, + w_bit, + group_size, + bias, + device, + colwise=True, + is_tp=is_tp, + layer_type=layer_type) self.setup_loaders() def setup_loaders(self): @@ -212,7 +227,7 @@ def _get_io_features(self, in_features: int, out_features: int, w_bit: int, grou def _update_all_out_features(self, all_out_features: List[int], w_bit: int, group_size: int): """Update all out features.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() new_all_out_features = [] align = max(32 // w_bit, group_size) for out_feat in all_out_features: @@ -222,7 +237,7 @@ def _update_all_out_features(self, all_out_features: List[int], w_bit: int, grou def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any): """Weight loader.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() shard_idx = self.out_names_map[shard_id] if loaded_weight.dim() == 1: # bias @@ -268,13 +283,16 @@ def __init__(self, device: Optional[torch.device] = None, is_tp: bool = True, num_replicate_kv_heads: int = 1): + self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type='attn') QKVMixin.__init__(self, num_q_heads=num_q_heads, num_kv_heads=num_kv_heads, head_size=head_size, head_size_v=head_size_v, num_replicate_kv_heads=num_replicate_kv_heads, - is_tp=is_tp) + is_tp=is_tp, + tp=self.tp, + tp_rank=self.tp_rank) elem_per_int = 32 // w_bit self.qkv_split_section_s = self.qkv_split_section @@ -288,7 +306,8 @@ def __init__(self, bias=bias, device=device, is_tp=is_tp, - out_names=out_names) + out_names=out_names, + layer_type='attn') def _update_all_out_features(self, all_out_features: List[int], w_bit: int, group_size: int): """Update all out features.""" @@ -296,7 +315,7 @@ def _update_all_out_features(self, all_out_features: List[int], w_bit: int, grou def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any): """Weight loader.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() chunk_size, chunk_idx = world_size, rank shard_idx = self.out_names_map[shard_id] diff --git a/lmdeploy/pytorch/nn/linear/base.py b/lmdeploy/pytorch/nn/linear/base.py index e6711f271e..4b2ab2c686 100644 --- a/lmdeploy/pytorch/nn/linear/base.py +++ b/lmdeploy/pytorch/nn/linear/base.py @@ -5,32 +5,30 @@ import torch.distributed as dist from torch import nn -from lmdeploy.pytorch.distributed import get_tp_world_rank +from lmdeploy.pytorch.config import TPMode +from lmdeploy.pytorch.distributed import get_dist_manager, get_tp_group, get_tp_world_rank from lmdeploy.pytorch.model_inputs import get_step_ctx_manager from .utils import update_tp_args -def _gather_input(x: torch.Tensor, tp_sizes: List[int]): +def _gather_input(x: torch.Tensor, tp_sizes: List[int], group: dist.ProcessGroup): """Gather input.""" shape0 = 
x.shape[:-2] shape1 = x.shape[-1:] shapes = [shape0 + (size, ) + shape1 for size in tp_sizes] new_x = [x.new_empty(shape) for shape in shapes] - dist.all_gather(new_x, x) + dist.all_gather(new_x, x, group=group) x = torch.cat(new_x, dim=-2) return x -def _reduce_scatter_input(out: torch.Tensor, tp_sizes: List[int]): +def _reduce_scatter_input(out: torch.Tensor, rank: int, tp_sizes: List[int], group: dist.ProcessGroup): """Reduce scatter.""" - _, rank = get_tp_world_rank() - out = out.transpose(0, -2) - if not out.is_contiguous(): - out = out.contiguous() + out = out.transpose(0, -2).contiguous() outs = out.split(tp_sizes, 0) out = outs[rank] - dist.reduce_scatter(out, outs) + dist.reduce_scatter(out, outs, group=group) out = out.transpose(0, -2) return out @@ -47,25 +45,50 @@ def __init__( all_reduce: bool = True, tp_align_size: int = 1, dp_gather: bool = False, - dp_scatter: bool = False, + layer_type: str = 'attn', ): super().__init__() - is_tp, all_reduce = update_tp_args(is_tp, all_reduce, colwise) + self.init_tp_args(is_tp, all_reduce, colwise, layer_type) self.colwise = colwise - self.is_tp = is_tp - self.all_reduce = all_reduce self.tp_align_size = tp_align_size self.dp_gather = dp_gather - self.dp_scatter = dp_scatter if device is None: device = torch.device('cpu') if dtype is None: dtype = torch.float16 self.device = device self.dtype = dtype + self.layer_type = layer_type self.lora_adapters = nn.ModuleDict() + def init_tp_args(self, is_tp: bool, all_reduce: bool, colwise: bool, layer_type: str): + if getattr(self, '_tp_args_initialized', False): + return + is_tp, all_reduce = update_tp_args(is_tp, all_reduce, colwise, layer_type=layer_type) + self.is_tp = is_tp + self.all_reduce = all_reduce + if is_tp: + dist_cfg = get_dist_manager().current_config() + _, rank = get_tp_world_rank(layer_type) + tp, tp_mode = dist_cfg.get_tp_by_layer(layer_type) + self.tp_rank = rank + self.tp = tp + self.tp_mode = tp_mode + self.tp_group = get_tp_group(layer_type=layer_type) + else: + self.tp_rank = 0 + self.tp = 1 + self.tp_mode = TPMode.DEFAULT + self.tp_group = None + + self._tp_args_initialized = True + + def get_tp_world_rank(self): + """Get tp world rank.""" + assert hasattr(self, 'tp') and hasattr(self, 'tp_rank'), 'Please run init_tp_args first.' 
+ return self.tp, self.tp_rank + def update_weights(self): """Update weights.""" raise NotImplementedError('This method should be implemented in subclasses.') @@ -81,22 +104,22 @@ def _forward_lora(self, x, tp_sizes: List[int]): for lora_adapter in self.lora_adapters.values(): out = lora_adapter(x, out) if self.all_reduce: - if self.dp_scatter: - out = _reduce_scatter_input(out, tp_sizes) + if self.tp_mode == TPMode.DP_TP: + out = _reduce_scatter_input(out, self.tp_rank, tp_sizes, group=self.tp_group) else: - dist.all_reduce(out) + dist.all_reduce(out, group=self.tp_group) return out def forward(self, x): """Forward of linear layer.""" tp_sizes = None - if self.dp_gather or self.dp_scatter: + if self.dp_gather or (self.all_reduce and self.tp_mode == TPMode.DP_TP): step_ctx = get_step_ctx_manager().current_context() dp_meta = step_ctx.dp_meta tp_sizes = dp_meta.tp_sizes if self.dp_gather: - x = _gather_input(x, tp_sizes) + x = _gather_input(x, tp_sizes, group=self.tp_group) if len(self.lora_adapters) == 0: return self._forward_default(x, self.all_reduce, tp_sizes) diff --git a/lmdeploy/pytorch/nn/linear/blocked_fp8.py b/lmdeploy/pytorch/nn/linear/blocked_fp8.py index 0638bf7650..163d4ee1f4 100644 --- a/lmdeploy/pytorch/nn/linear/blocked_fp8.py +++ b/lmdeploy/pytorch/nn/linear/blocked_fp8.py @@ -4,13 +4,12 @@ import torch from lmdeploy.pytorch.backends import OpType, get_backend -from lmdeploy.pytorch.distributed import get_tp_world_rank from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader from ..quant_utils import quant_blocked_fp8 from ..utils import div_up, get_distribute_size from .base import LinearBase -from .utils import QKVMixin, _get_tp_world_rank, check_qkv_split_layout +from .utils import QKVMixin, check_qkv_split_layout class BlockedF8Linear(LinearBase): @@ -28,7 +27,7 @@ def __init__( is_tp: bool = False, all_reduce: bool = True, dp_gather: bool = False, - dp_scatter: bool = False, + layer_type: str = 'attn', ): super().__init__(dtype=dtype, device=device, @@ -36,7 +35,7 @@ def __init__( is_tp=is_tp, all_reduce=all_reduce, dp_gather=dp_gather, - dp_scatter=dp_scatter) + layer_type=layer_type) self.block_size = 128 self.fp8_dtype = fp8_dtype if self.is_tp: @@ -76,7 +75,7 @@ def register_all_parameters(self, def _get_io_features(self, in_features: int, out_features: int, colwise: bool): """Get io features.""" - world_size, rank = get_tp_world_rank() + world_size, rank = self.get_tp_world_rank() if colwise: out_features = get_distribute_size(out_features, world_size, rank) else: @@ -107,7 +106,7 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor): if not self.is_tp: return default_weight_loader(param, loaded_weight) - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() if self.colwise: return self._weight_loader_tp_colwise(param, loaded_weight, rank, world_size) else: @@ -142,11 +141,18 @@ def update_weights(self): def _forward_default(self, x, all_reduce, tp_sizes): """Default forward implement.""" - if self.dp_scatter: - _, rank = get_tp_world_rank() - return self.impl.forward(x, self.weight, self.weight_scale_inv, self.bias, all_reduce, rank, tp_sizes) + if self.tp_mode: + rank = self.tp_rank + return self.impl.forward(x, + self.weight, + self.weight_scale_inv, + self.bias, + all_reduce, + group=self.tp_group, + rank=rank, + scatter_size=tp_sizes) else: - return self.impl.forward(x, self.weight, self.weight_scale_inv, self.bias, all_reduce) + return self.impl.forward(x, 
self.weight, self.weight_scale_inv, self.bias, all_reduce, group=self.tp_group) class MergedBlockedF8Linear(BlockedF8Linear): @@ -162,12 +168,13 @@ def __init__(self, device: Optional[torch.device] = None, is_tp: bool = True, out_names: Optional[List[int]] = None, - dp_gather: bool = False): + dp_gather: bool = False, + layer_type: str = 'attn'): + self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type=layer_type) if replicate is None: replicate = tuple(False for _ in all_out_features) self.block_size = 128 self.split_section = all_out_features - self.is_tp = is_tp self.scale_split_section = [section // self.block_size for section in self.split_section] all_out_features = self._update_all_out_features(all_out_features, replicate) self.all_out_features = all_out_features @@ -185,7 +192,8 @@ def __init__(self, fp8_dtype=fp8_dtype, colwise=True, is_tp=is_tp, - dp_gather=dp_gather) + dp_gather=dp_gather, + layer_type=layer_type) self.setup_loaders() def setup_loaders(self): @@ -207,7 +215,7 @@ def _get_io_features(self, in_features: int, out_features: int, colwise: bool): def _update_all_out_features(self, all_out_features: List[int], replicate: Optional[List[bool]]): """Update all out features.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() new_all_out_features = [] for out_feat, rep in zip(all_out_features, replicate): if rep: @@ -218,7 +226,7 @@ def _update_all_out_features(self, all_out_features: List[int], replicate: Optio def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any): """Weight loader.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() shard_idx = self.out_names_map[shard_id] if loaded_weight.dim() == 2 and loaded_weight.dtype != self.fp8_dtype: loaded_weight = loaded_weight.to(torch.float32) @@ -267,13 +275,16 @@ def __init__(self, dp_gather: bool = False, num_replicate_kv_heads: int = 1): self.block_size = 128 + self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type='attn') QKVMixin.__init__(self, num_q_heads=num_q_heads, num_kv_heads=num_kv_heads, head_size=head_size, head_size_v=head_size_v, num_replicate_kv_heads=num_replicate_kv_heads, - is_tp=is_tp) + is_tp=is_tp, + tp=self.tp, + tp_rank=self.tp_rank) all_out_features = self.get_qkv_out_feautures() out_names = ('q', 'k', 'v') @@ -285,7 +296,8 @@ def __init__(self, device=device, is_tp=is_tp, out_names=out_names, - dp_gather=dp_gather) + dp_gather=dp_gather, + layer_type='attn') def _update_all_out_features(self, all_out_features: List[int], replicate: Optional[List[bool]]): """Update all out features.""" @@ -293,7 +305,7 @@ def _update_all_out_features(self, all_out_features: List[int], replicate: Optio def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any): """Weight loader.""" - _, rank = _get_tp_world_rank(self.is_tp) + _, rank = self.get_tp_world_rank() shard_idx = self.out_names_map[shard_id] num_head = self.num_q_heads if shard_id == 'q' \ diff --git a/lmdeploy/pytorch/nn/linear/default.py b/lmdeploy/pytorch/nn/linear/default.py index a2556710a6..a3f8a31a2c 100644 --- a/lmdeploy/pytorch/nn/linear/default.py +++ b/lmdeploy/pytorch/nn/linear/default.py @@ -4,12 +4,12 @@ import torch from lmdeploy.pytorch.backends import OpType, get_backend -from lmdeploy.pytorch.distributed import get_tp_world_rank +from lmdeploy.pytorch.config import TPMode from lmdeploy.pytorch.weight_loader.model_weight_loader import 
default_weight_loader from ..utils import chunk_aligned, get_distribute_size from .base import LinearBase -from .utils import QKVMixin, _get_tp_world_rank, check_qkv_split_layout +from .utils import QKVMixin, check_qkv_split_layout class BaseLinear(LinearBase): @@ -27,7 +27,7 @@ def __init__( all_reduce: bool = True, tp_align_size: int = 1, dp_gather: bool = False, - dp_scatter: bool = False, + layer_type: str = 'attn', ): super().__init__(dtype=dtype, device=device, @@ -36,7 +36,7 @@ def __init__( all_reduce=all_reduce, tp_align_size=tp_align_size, dp_gather=dp_gather, - dp_scatter=dp_scatter) + layer_type=layer_type) if self.is_tp: in_features, out_features = self._get_io_features(in_features, out_features, colwise) impl_builder = get_backend().get_layer_impl_builder(OpType.Linear) @@ -64,7 +64,7 @@ def register_all_parameters(self, weight: torch.Tensor, bias: Optional[torch.Ten def _get_io_features(self, in_features: int, out_features: int, colwise: bool): """Get io features.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() if colwise: out_features = get_distribute_size(out_features, world_size, rank, align=self.tp_align_size) else: @@ -95,7 +95,7 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor): if not self.is_tp: return default_weight_loader(param, loaded_weight) - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() if self.colwise: return self._weight_loader_tp_colwise(param, loaded_weight, rank, world_size) else: @@ -117,11 +117,17 @@ def update_weights(self): def _forward_default(self, x, all_reduce, tp_sizes): """Default forward implement.""" - if self.dp_scatter: - _, rank = get_tp_world_rank() - return self.impl.forward(x, self.weight, self.bias, all_reduce, rank, tp_sizes) + if self.tp_mode == TPMode.DP_TP: + rank = self.tp_rank + return self.impl.forward(x, + self.weight, + self.bias, + all_reduce, + group=self.tp_group, + rank=rank, + scatter_size=tp_sizes) else: - return self.impl.forward(x, self.weight, self.bias, all_reduce) + return self.impl.forward(x, self.weight, self.bias, all_reduce, group=self.tp_group) class MergedBaseLinear(BaseLinear): @@ -135,9 +141,10 @@ def __init__(self, device: Optional[torch.device] = None, is_tp: bool = True, out_names: Optional[List[int]] = None, - dp_gather: bool = False): + dp_gather: bool = False, + layer_type: str = 'attn'): + self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type=layer_type) self.split_section = all_out_features - self.is_tp = is_tp all_out_features = self._update_all_out_features(all_out_features) self.all_out_features = all_out_features if out_names is None: @@ -145,7 +152,15 @@ def __init__(self, assert len(out_names) == len(self.all_out_features) self.out_names_map = dict((name, idx) for idx, name in enumerate(out_names)) out_features = sum(all_out_features) - super().__init__(in_features, out_features, bias, dtype, device, colwise=True, is_tp=is_tp, dp_gather=dp_gather) + super().__init__(in_features, + out_features, + bias, + dtype, + device, + colwise=True, + is_tp=is_tp, + dp_gather=dp_gather, + layer_type=layer_type) self.setup_loaders() def setup_loaders(self): @@ -162,7 +177,7 @@ def _get_io_features(self, in_features: int, out_features: int, colwise: bool): def _update_all_out_features(self, all_out_features: List[int]): """Update all out features.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() 
new_all_out_features = [] for out_feat in all_out_features: new_out_feat = get_distribute_size(out_feat, world_size, rank) @@ -171,7 +186,7 @@ def _update_all_out_features(self, all_out_features: List[int]): def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any): """Weight loader.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() shard_idx = self.out_names_map[shard_id] param_w = param.data.split(self.all_out_features, 0)[shard_idx] loaded_weight = loaded_weight.chunk(world_size, 0)[rank] @@ -199,13 +214,16 @@ def __init__(self, device: Optional[torch.device] = None, is_tp: bool = True, num_replicate_kv_heads: int = 1): + self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type='attn') QKVMixin.__init__(self, num_q_heads=num_q_heads, num_kv_heads=num_kv_heads, head_size=head_size, head_size_v=head_size_v, num_replicate_kv_heads=num_replicate_kv_heads, - is_tp=is_tp) + is_tp=is_tp, + tp=self.tp, + tp_rank=self.tp_rank) all_out_features = self.get_qkv_out_feautures() out_names = ('q', 'k', 'v') @@ -215,7 +233,8 @@ def __init__(self, dtype=dtype, device=device, is_tp=is_tp, - out_names=out_names) + out_names=out_names, + layer_type='attn') def _update_all_out_features(self, all_out_features: List[int]): """Update all out features.""" @@ -223,7 +242,7 @@ def _update_all_out_features(self, all_out_features: List[int]): def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any): """Weight loader.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() chunk_size, chunk_idx = world_size, rank shard_idx = self.out_names_map[shard_id] param_w = param.data.split(self.all_out_features, 0)[shard_idx] diff --git a/lmdeploy/pytorch/nn/linear/utils.py b/lmdeploy/pytorch/nn/linear/utils.py index 433a26c720..3e8c7db6f6 100644 --- a/lmdeploy/pytorch/nn/linear/utils.py +++ b/lmdeploy/pytorch/nn/linear/utils.py @@ -17,18 +17,10 @@ def check_qkv_split_layout(layout: str): f'but get: {layout}') -def _get_tp_world_rank(is_tp: bool): - """Get tp world size.""" - if is_tp: - return get_tp_world_rank() - else: - return 1, 0 - - -def update_tp_args(is_tp: bool, all_reduce: bool, colwise: bool): +def update_tp_args(is_tp: bool, all_reduce: bool, colwise: bool, layer_type: str = 'attn'): """Update tp args according to the environment.""" if is_tp: - world, _ = get_tp_world_rank() + world, _ = get_tp_world_rank(layer_type) is_tp = world > 1 if not is_tp or colwise: @@ -46,10 +38,12 @@ def __init__(self, head_size: int, head_size_v: int, num_replicate_kv_heads: int = 1, - is_tp: bool = False): + is_tp: bool = False, + tp: int = 1, + tp_rank: int = 0): qkv_split_section = self._get_qkv_out_features(num_q_heads, num_kv_heads, head_size, head_size_v, num_replicate_kv_heads) - num_q_heads, num_kv_heads = self._update_num_heads(is_tp, num_q_heads, num_kv_heads) + num_q_heads, num_kv_heads = self._update_num_heads(is_tp, tp, tp_rank, num_q_heads, num_kv_heads) self.num_q_heads = num_q_heads self.num_kv_heads = num_kv_heads self.head_size = head_size @@ -72,11 +66,11 @@ def _get_qkv_out_features(self, all_out_features = (num_q_heads * head_size, num_kv_heads_real * head_size, num_kv_heads_real * head_size_v) return all_out_features - def _update_num_heads(self, is_tp: bool, num_q_heads: int, num_kv_heads: int): + def _update_num_heads(self, is_tp: bool, tp: int, tp_rank: int, num_q_heads: int, num_kv_heads: int): """Update num heads.""" if not 
is_tp: return num_q_heads, num_kv_heads - world_size, rank = get_tp_world_rank() + world_size, rank = tp, tp_rank num_q_heads = get_distribute_size(num_q_heads, world_size, rank) num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank) diff --git a/lmdeploy/pytorch/nn/linear/w8a8.py b/lmdeploy/pytorch/nn/linear/w8a8.py index d5d0d476f7..c9105e5599 100644 --- a/lmdeploy/pytorch/nn/linear/w8a8.py +++ b/lmdeploy/pytorch/nn/linear/w8a8.py @@ -8,7 +8,7 @@ from ..utils import get_distribute_size from .base import LinearBase -from .utils import QKVMixin, _get_tp_world_rank, check_qkv_split_layout +from .utils import QKVMixin, check_qkv_split_layout class W8A8Linear(LinearBase): @@ -23,8 +23,14 @@ def __init__(self, colwise: bool = True, is_tp: bool = False, all_reduce: bool = True, - quant_dtype: Optional[torch.dtype] = torch.int8): - super().__init__(dtype=torch.float16, device=device, colwise=colwise, is_tp=is_tp, all_reduce=all_reduce) + quant_dtype: Optional[torch.dtype] = torch.int8, + layer_type: str = 'attn'): + super().__init__(dtype=torch.float16, + device=device, + colwise=colwise, + is_tp=is_tp, + all_reduce=all_reduce, + layer_type=layer_type) if self.is_tp: in_features, out_features = self._get_io_features(in_features, out_features, colwise) impl_builder = get_backend().get_layer_impl_builder(OpType.LinearW8A8) @@ -60,7 +66,7 @@ def register_all_parameters(self, weight: torch.Tensor, scale: torch.Tensor, bia def _get_io_features(self, in_features: int, out_features: int, colwise: bool): """Get io features.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() if colwise: out_features = get_distribute_size(out_features, world_size, rank) else: @@ -94,7 +100,7 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor): if not self.is_tp: return default_weight_loader(param, loaded_weight) - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() if self.colwise: return self._weight_loader_tp_colwise(param, loaded_weight, rank, world_size) else: @@ -117,7 +123,7 @@ def update_weights(self): def _forward_default(self, x, all_reduce, tp_sizes): """Default forward implement.""" - return self.impl.forward(x, self.weight, self.scale, self.bias, all_reduce) + return self.impl.forward(x, self.weight, self.scale, self.bias, all_reduce, group=self.tp_group) class MergedW8A8Linear(W8A8Linear): @@ -131,9 +137,10 @@ def __init__(self, device: Optional[torch.device] = None, is_tp: bool = True, out_names: Optional[List[int]] = None, - quant_dtype: torch.dtype = torch.int8): + quant_dtype: torch.dtype = torch.int8, + layer_type: str = 'attn'): + self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type=layer_type) self.split_section = all_out_features - self.is_tp = is_tp all_out_features = self._update_all_out_features(all_out_features) self.all_out_features = all_out_features if out_names is None: @@ -148,7 +155,8 @@ def __init__(self, device, colwise=True, is_tp=is_tp, - quant_dtype=quant_dtype) + quant_dtype=quant_dtype, + layer_type=layer_type) self.setup_loaders() def setup_loaders(self): @@ -167,7 +175,7 @@ def _get_io_features(self, in_features: int, out_features: int, colwise: bool): def _update_all_out_features(self, all_out_features: List[int]): """Update all out features.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() new_all_out_features = [] for out_feat in all_out_features: new_out_feat = 
get_distribute_size(out_feat, world_size, rank) @@ -176,7 +184,7 @@ def _update_all_out_features(self, all_out_features: List[int]): def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any): """Weight loader.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() shard_idx = self.out_names_map[shard_id] param_w = param.data.split(self.all_out_features, 0)[shard_idx] loaded_weight = loaded_weight.chunk(world_size, 0)[rank] @@ -205,13 +213,16 @@ def __init__(self, is_tp: bool = True, num_replicate_kv_heads: int = 1, quant_dtype: torch.dtype = torch.int8): + self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type='attn') QKVMixin.__init__(self, num_q_heads=num_q_heads, num_kv_heads=num_kv_heads, head_size=head_size, head_size_v=head_size_v, num_replicate_kv_heads=num_replicate_kv_heads, - is_tp=is_tp) + is_tp=is_tp, + tp=self.tp, + tp_rank=self.tp_rank) all_out_features = self.get_qkv_out_feautures() out_names = ('q', 'k', 'v') @@ -222,7 +233,8 @@ def __init__(self, device=device, is_tp=is_tp, out_names=out_names, - quant_dtype=quant_dtype) + quant_dtype=quant_dtype, + layer_type='attn') def _update_all_out_features(self, all_out_features: List[int]): """Update all out features.""" @@ -230,7 +242,7 @@ def _update_all_out_features(self, all_out_features: List[int]): def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any): """Weight loader.""" - _, rank = _get_tp_world_rank(self.is_tp) + _, rank = self.get_tp_world_rank() shard_idx = self.out_names_map[shard_id] param_w = param.data.split(self.all_out_features, 0)[shard_idx] num_head = self.num_q_heads if shard_id == 'q' \ diff --git a/lmdeploy/pytorch/nn/moe.py b/lmdeploy/pytorch/nn/moe.py index 0f88cdf851..75e510927b 100644 --- a/lmdeploy/pytorch/nn/moe.py +++ b/lmdeploy/pytorch/nn/moe.py @@ -47,7 +47,7 @@ def create_mlp_weights(hidden_dim: int, ffn_dim: int, num_experts: int, dtype: t def _update_args(hidden_dim: int, ffn_dim: int): """Update args.""" - world_size, _ = get_tp_world_rank() + world_size, _ = get_tp_world_rank('moe') assert ffn_dim % world_size == 0 ffn_dim = ffn_dim // world_size return hidden_dim, ffn_dim @@ -107,7 +107,7 @@ def update_weight(self, weight: torch.Tensor): def weight_loader_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, shard_id: str): """Weight loader.""" - world_size, rank = get_tp_world_rank() + world_size, rank = get_tp_world_rank('moe') if shard_id == 'gate': param_data = param.data[expert_id, :self.half_out] weight = loaded_weight.chunk(world_size, dim=0)[rank] @@ -160,7 +160,7 @@ def _gather_input(x: torch.Tensor, tp_sizes: List[int]): def _reduce_scatter_input(out: torch.Tensor, tp_sizes: List[int]): """Reduce scatter.""" - _, rank = get_tp_world_rank() + _, rank = get_tp_world_rank('moe') out = out.transpose(0, -2) if not out.is_contiguous(): out = out.contiguous() @@ -173,15 +173,15 @@ def _reduce_scatter_input(out: torch.Tensor, tp_sizes: List[int]): def _moe_gather_inputs(hidden_states, topk_weights, topk_ids, enable_ep): - dist_ctx = get_dist_manager().current_context() - dp = dist_ctx.dp + dist_config = get_dist_manager().current_config() + dp = dist_config.dp if dp <= 1: return hidden_states, topk_weights, topk_ids step_ctx = get_step_ctx_manager().current_context() dp_meta = step_ctx.dp_meta if not enable_ep: - if dist_ctx.tp == 1: + if dist_config.tp == 1: return hidden_states, topk_weights, topk_ids tp_sizes = dp_meta.tp_sizes 
hidden_states = _gather_input(hidden_states, tp_sizes) @@ -193,13 +193,13 @@ def _moe_gather_inputs(hidden_states, topk_weights, topk_ids, enable_ep): def _moe_reduce(ret, enable_ep): - dist_ctx = get_dist_manager().current_context() - dp = dist_ctx.dp + dist_config = get_dist_manager().current_config() + dp = dist_config.dp if dp > 1: step_ctx = get_step_ctx_manager().current_context() dp_meta = step_ctx.dp_meta if not enable_ep: - if dist_ctx.tp == 1: + if dist_config.tp == 1: return ret tp_sizes = dp_meta.tp_sizes ret = _reduce_scatter_input(ret, tp_sizes) @@ -236,7 +236,7 @@ def __init__(self, enable_ep = enable_ep and self.impl.support_ep() if enable_ep: - world_size, rank = get_tp_world_rank() + world_size, rank = get_tp_world_rank('moe') expert_list = self.impl.ep_expert_list(world_size, rank) num_experts = len(expert_list) else: @@ -269,7 +269,7 @@ def __init__(self, self.num_experts = num_experts self.dtype = dtype self.device = device - world_size, _ = get_tp_world_rank() + world_size, _ = get_tp_world_rank('moe') if world_size == 1: all_reduce = False self.all_reduce = all_reduce @@ -342,7 +342,7 @@ def update_weight(self, weight: torch.Tensor, scale: torch.Tensor): def weight_loader_scale_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, shard_id: str): """Weight loader scale tp.""" - world_size, rank = get_tp_world_rank() + world_size, rank = get_tp_world_rank('moe') if shard_id == 'gate': param_data = param.data[expert_id, :self.half_out] weight = loaded_weight.chunk(world_size, dim=0)[rank] @@ -383,7 +383,7 @@ def __init__(self, enable_ep = enable_ep and self.impl.support_ep() if enable_ep: - world_size, rank = get_tp_world_rank() + world_size, rank = get_tp_world_rank('moe') expert_list = self.impl.ep_expert_list(world_size, rank) num_experts = len(expert_list) else: @@ -413,7 +413,7 @@ def __init__(self, self.num_experts = num_experts self.dtype = dtype self.device = device - world_size, _ = get_tp_world_rank() + world_size, _ = get_tp_world_rank('moe') if world_size == 1: all_reduce = False self.all_reduce = all_reduce @@ -493,7 +493,7 @@ def weight_loader_scale_ep(self, param: torch.nn.Parameter, loaded_weight: torch def weight_loader_scale_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, shard_id: str): """Weight loader scale tp.""" - world_size, rank = get_tp_world_rank() + world_size, rank = get_tp_world_rank('moe') block_size = self.block_size half_out = self.half_out // block_size if shard_id == 'gate': @@ -594,7 +594,7 @@ def __init__(self, self.num_experts = num_experts self.dtype = dtype self.device = device - world_size, _ = get_tp_world_rank() + world_size, _ = get_tp_world_rank('moe') if world_size == 1: all_reduce = False self.all_reduce = all_reduce diff --git a/lmdeploy/pytorch/nn/norm.py b/lmdeploy/pytorch/nn/norm.py index 908624b3c8..f477467f0d 100644 --- a/lmdeploy/pytorch/nn/norm.py +++ b/lmdeploy/pytorch/nn/norm.py @@ -45,7 +45,7 @@ def __init__(self, builder = backend.get_layer_impl_builder(OpType.RMSNorm) if tp: - world_size, rank = get_tp_world_rank() + world_size, rank = get_tp_world_rank('attn') hidden_size = get_distribute_size(hidden_size, world_size, rank, align=align) self.register_parameter('weight', self.create_weight(hidden_size, dtype, device)) @@ -60,7 +60,7 @@ def __init__(self, def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): """Weight loader.""" - world_size, rank = get_tp_world_rank() + world_size, rank = get_tp_world_rank('attn') loaded_weight = 
chunk_aligned(loaded_weight, world_size, 0, self.align)[rank] param.copy_(loaded_weight) From 5269b62f2394bdb638739c08af44a87d2b4caa6c Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 23 Sep 2025 17:33:26 +0800 Subject: [PATCH 02/15] moe --- .../backends/cuda/blockedf8_modules.py | 17 +-- lmdeploy/pytorch/backends/default/linear.py | 15 +- lmdeploy/pytorch/check_env/dist.py | 4 +- lmdeploy/pytorch/config.py | 4 +- lmdeploy/pytorch/distributed.py | 19 +++ lmdeploy/pytorch/nn/linear/base.py | 28 +--- lmdeploy/pytorch/nn/moe.py | 139 ++++++++++-------- lmdeploy/serve/openai/launch_server.py | 5 +- 8 files changed, 111 insertions(+), 120 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py b/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py index bb70a9f718..e25e9c52e7 100644 --- a/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py +++ b/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py @@ -13,15 +13,6 @@ logger = get_logger('lmdeploy') -def _reduce_scatter_input(out: torch.Tensor, rank: int, tp_sizes: List[int], group: Optional[dist.ProcessGroup] = None): - """Reduce scatter.""" - outs = out.split(tp_sizes, -2) - out = outs[rank] - outs = list(outs) - dist.reduce_scatter(out, outs, group=group) - return out - - class TritonLinearBlockedF8Impl(LinearBlockedF8Impl): """Triton linear blocked f8 implementation.""" @@ -37,6 +28,7 @@ def forward(self, scale: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False, + group: Optional[dist.ProcessGroup] = None, rank: int = 0, scatter_size: List[int] = None): """forward.""" @@ -52,7 +44,7 @@ def forward(self, if all_reduce: if scatter_size is not None: - out = _reduce_scatter_input(out, rank, scatter_size) + out = dist.reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group) else: dist.all_reduce(out) return out @@ -129,12 +121,11 @@ def forward(self, out = out[:x.size(0)] if bias is not None: out += bias + out = out.unflatten(0, x_shape[:-1]) if all_reduce: if scatter_size is not None: - out = _reduce_scatter_input(out, rank, scatter_size, group=group) + out = dist.reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group) else: dist.all_reduce(out, group=group) - - out = out.unflatten(0, x_shape[:-1]) return out diff --git a/lmdeploy/pytorch/backends/default/linear.py b/lmdeploy/pytorch/backends/default/linear.py index 6e4870b0e5..f766123fff 100644 --- a/lmdeploy/pytorch/backends/default/linear.py +++ b/lmdeploy/pytorch/backends/default/linear.py @@ -8,18 +8,6 @@ from ..linear import LinearBuilder, LinearImpl -def _reduce_scatter_input(out: torch.Tensor, rank: int, tp_sizes: List[int], group: dist.ProcessGroup = None): - """Reduce scatter.""" - out = out.transpose(0, -2) - out = out.contiguous() - outs = out.split(tp_sizes, 0) - out = outs[rank] - outs = list(outs) - dist.reduce_scatter(out, outs, group=group) - out = out.transpose(0, -2) - return out - - class DefaultLinearImpl(LinearImpl): """Linear implementation api.""" @@ -35,7 +23,8 @@ def forward(self, out = F.linear(x, weight, bias) if all_reduce: if scatter_size is not None: - out = _reduce_scatter_input(out, rank, scatter_size, group=group) + from lmdeploy.pytorch.distributed import reduce_scatter_by_tp_sizes + out = reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group) else: dist.all_reduce(out, group=group) return out diff --git a/lmdeploy/pytorch/check_env/dist.py b/lmdeploy/pytorch/check_env/dist.py index e31e6c9161..eb6e48e882 100644 --- a/lmdeploy/pytorch/check_env/dist.py +++ 
b/lmdeploy/pytorch/check_env/dist.py @@ -44,9 +44,9 @@ def check(self): if self.device_type == 'cuda' and not is_dlblas_installed(): self.log_and_exit(mod_name='Dist', message='ep>1 requires install dlblas(https://github.com/DeepLink-org/dlBLAS).') - if self.dp % self.ep != 0: + if self.ep % self.dp != 0: self.log_and_exit(mod_name='Dist', - message=f'ep>1 requires dp % ep == 0. Get dp={self.dp} and ep={self.ep}.') + message=f'ep>1 requires ep % dp == 0. Get dp={self.dp} and ep={self.ep}.') elif self.dist_config.enable_eplb: self.log_and_exit(mod_name='Dist', message=f'Enable eplb requires ep > 1. Get ep={self.ep}.') diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 7e73ed7fb8..4124783987 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -148,8 +148,8 @@ def __post_init__(self): self.tp = self.attn_tp # tp mode - self.mlp_tp_mode = TPMode.DEFAULT if self.attn_tp == self.mlp_tp else TPMode.DP_TP - self.moe_tp_mode = TPMode.DEFAULT if self.attn_tp == self.moe_tp else TPMode.DP_TP + self.mlp_tp_mode = TPMode.DEFAULT if (self.mlp_tp in [1, self.attn_tp]) else TPMode.DP_TP + self.moe_tp_mode = TPMode.DEFAULT if (self.moe_tp in [1, self.attn_tp]) else TPMode.DP_TP def get_tp_by_layer(self, layer_type: str): """Get tp by layer type.""" diff --git a/lmdeploy/pytorch/distributed.py b/lmdeploy/pytorch/distributed.py index b17a44a802..05f9391c3d 100644 --- a/lmdeploy/pytorch/distributed.py +++ b/lmdeploy/pytorch/distributed.py @@ -5,6 +5,7 @@ from datetime import timedelta from typing import List, Optional +import torch from torch import distributed as dist from torch.distributed import ProcessGroup, ReduceOp # noqa: F401 @@ -400,3 +401,21 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group='tp', async_op=Fal if isinstance(group, str): group = get_group(group, 'gpu') return dist.reduce_scatter(output, input_list, op=op, group=group, async_op=async_op) + + +def gather_by_tp_sizes(x: torch.Tensor, tp_sizes: List[int], group: Optional[dist.ProcessGroup] = None): + """Gather input.""" + shape = (*x.shape[:-2], sum(tp_sizes), *x.shape[-1:]) + new_x = x.new_empty(shape) + split_new_x = list(new_x.split(tp_sizes, -2)) + dist.all_gather(split_new_x, x, group=group) + return new_x + + +def reduce_scatter_by_tp_sizes(out: torch.Tensor, rank: int, tp_sizes: List[int], group: dist.ProcessGroup): + """Reduce scatter.""" + outs = out.split(tp_sizes, -2) + out = outs[rank] + outs = list(outs) + dist.reduce_scatter(out, outs, group=group) + return out diff --git a/lmdeploy/pytorch/nn/linear/base.py b/lmdeploy/pytorch/nn/linear/base.py index 4b2ab2c686..90057a920b 100644 --- a/lmdeploy/pytorch/nn/linear/base.py +++ b/lmdeploy/pytorch/nn/linear/base.py @@ -6,33 +6,13 @@ from torch import nn from lmdeploy.pytorch.config import TPMode -from lmdeploy.pytorch.distributed import get_dist_manager, get_tp_group, get_tp_world_rank +from lmdeploy.pytorch.distributed import (gather_by_tp_sizes, get_dist_manager, get_tp_group, get_tp_world_rank, + reduce_scatter_by_tp_sizes) from lmdeploy.pytorch.model_inputs import get_step_ctx_manager from .utils import update_tp_args -def _gather_input(x: torch.Tensor, tp_sizes: List[int], group: dist.ProcessGroup): - """Gather input.""" - shape0 = x.shape[:-2] - shape1 = x.shape[-1:] - shapes = [shape0 + (size, ) + shape1 for size in tp_sizes] - new_x = [x.new_empty(shape) for shape in shapes] - dist.all_gather(new_x, x, group=group) - x = torch.cat(new_x, dim=-2) - return x - - -def _reduce_scatter_input(out: torch.Tensor, 
rank: int, tp_sizes: List[int], group: dist.ProcessGroup): - """Reduce scatter.""" - out = out.transpose(0, -2).contiguous() - outs = out.split(tp_sizes, 0) - out = outs[rank] - dist.reduce_scatter(out, outs, group=group) - out = out.transpose(0, -2) - return out - - class LinearBase(nn.Module): """Base class for linear layers.""" @@ -105,7 +85,7 @@ def _forward_lora(self, x, tp_sizes: List[int]): out = lora_adapter(x, out) if self.all_reduce: if self.tp_mode == TPMode.DP_TP: - out = _reduce_scatter_input(out, self.tp_rank, tp_sizes, group=self.tp_group) + out = reduce_scatter_by_tp_sizes(out, self.tp_rank, tp_sizes, group=self.tp_group) else: dist.all_reduce(out, group=self.tp_group) return out @@ -119,7 +99,7 @@ def forward(self, x): tp_sizes = dp_meta.tp_sizes if self.dp_gather: - x = _gather_input(x, tp_sizes, group=self.tp_group) + x = gather_by_tp_sizes(x, tp_sizes, group=self.tp_group) if len(self.lora_adapters) == 0: return self._forward_default(x, self.all_reduce, tp_sizes) diff --git a/lmdeploy/pytorch/nn/moe.py b/lmdeploy/pytorch/nn/moe.py index 75e510927b..13529db12c 100644 --- a/lmdeploy/pytorch/nn/moe.py +++ b/lmdeploy/pytorch/nn/moe.py @@ -7,6 +7,7 @@ from torch import nn import lmdeploy.pytorch.distributed as dist +from lmdeploy.pytorch.config import TPMode from lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank, get_tp_world_rank from lmdeploy.pytorch.model_inputs import get_step_ctx_manager @@ -147,67 +148,44 @@ def weight_loader_ep(self, param: torch.nn.Parameter, loaded_weight: torch.Tenso param_data.copy_(loaded_weight) -def _gather_input(x: torch.Tensor, tp_sizes: List[int]): - """Gather input.""" - shape0 = x.shape[:-2] - shape1 = x.shape[-1:] - shapes = [shape0 + (size, ) + shape1 for size in tp_sizes] - new_x = [x.new_empty(shape) for shape in shapes] - dist.all_gather(new_x, x) - x = torch.cat(new_x, dim=-2) - return x - - -def _reduce_scatter_input(out: torch.Tensor, tp_sizes: List[int]): - """Reduce scatter.""" - _, rank = get_tp_world_rank('moe') - out = out.transpose(0, -2) - if not out.is_contiguous(): - out = out.contiguous() - outs = out.split(tp_sizes, 0) - outs = list(outs) - out = outs[rank] - dist.reduce_scatter(out, outs) - out = out.transpose(0, -2) - return out - - -def _moe_gather_inputs(hidden_states, topk_weights, topk_ids, enable_ep): +def _moe_gather_inputs(hidden_states, topk_weights, topk_ids, group: Optional[dist.ProcessGroup] = None): dist_config = get_dist_manager().current_config() - dp = dist_config.dp - if dp <= 1: + tp = dist_config.moe_tp + if tp == 1: return hidden_states, topk_weights, topk_ids - step_ctx = get_step_ctx_manager().current_context() - dp_meta = step_ctx.dp_meta - if not enable_ep: - if dist_config.tp == 1: - return hidden_states, topk_weights, topk_ids - tp_sizes = dp_meta.tp_sizes - hidden_states = _gather_input(hidden_states, tp_sizes) - topk_weights = _gather_input(topk_weights, tp_sizes) - topk_ids = _gather_input(topk_ids, tp_sizes) + tp_mode = dist_config.moe_tp_mode + if tp_mode == TPMode.DEFAULT: + return hidden_states, topk_weights, topk_ids + elif tp_mode == TPMode.DP_TP: + step_ctx = get_step_ctx_manager().current_context() + dp_meta = step_ctx.dp_meta + tp_sizes = dp_meta.moe_tp_sizes + hidden_states = dist.gather_by_tp_sizes(hidden_states, tp_sizes, group=group) + topk_weights = dist.gather_by_tp_sizes(topk_weights, tp_sizes, group=group) + topk_ids = dist.gather_by_tp_sizes(topk_ids, tp_sizes, group=group) else: raise RuntimeError('Not supported.') + return hidden_states, 
topk_weights, topk_ids -def _moe_reduce(ret, enable_ep): +def _moe_reduce(ret, rank: int, tp_mode: TPMode, group: Optional[dist.ProcessGroup] = None): dist_config = get_dist_manager().current_config() - dp = dist_config.dp - if dp > 1: + if dist_config.moe_tp == 1: + return ret + + if tp_mode == TPMode.DEFAULT: + dist.all_reduce(ret, group=group) + return ret + elif tp_mode == TPMode.DP_TP: step_ctx = get_step_ctx_manager().current_context() dp_meta = step_ctx.dp_meta - if not enable_ep: - if dist_config.tp == 1: - return ret - tp_sizes = dp_meta.tp_sizes - ret = _reduce_scatter_input(ret, tp_sizes) - else: - raise RuntimeError('Not supported.') + tp_size = dp_meta.moe_tp_sizes + ret = dist.reduce_scatter_by_tp_sizes(ret, rank, tp_size, group=group) + return ret else: - dist.all_reduce(ret) - return ret + raise RuntimeError('Not supported.') class FusedMoE(nn.Module): @@ -230,6 +208,7 @@ def __init__(self, device = torch.device('cpu') if dtype is None: dtype = torch.float16 + self.init_tp_args(all_reduce, enable_ep) impl_builder = get_backend().get_layer_impl_builder(OpType.FusedMoE) self.impl = impl_builder.build(top_k, num_experts, renormalize) @@ -269,13 +248,26 @@ def __init__(self, self.num_experts = num_experts self.dtype = dtype self.device = device - world_size, _ = get_tp_world_rank('moe') - if world_size == 1: - all_reduce = False - self.all_reduce = all_reduce self.enable_ep = enable_ep self.act_func = act_func + def init_tp_args(self, all_reduce: bool, enable_ep: bool): + """Init tp args.""" + tp, tp_rank = get_tp_world_rank('moe') + dist_ctx = get_dist_manager().current_context() + dist_cfg = dist_ctx.dist_config + _, tp_mode = dist_cfg.get_tp_by_layer('moe') + tp = 1 if enable_ep else tp + tp_rank = 0 if enable_ep else tp_rank + all_reduce = all_reduce if tp > 1 else False + all_reduce = False if enable_ep else all_reduce + + self.tp = tp + self.tp_rank = tp_rank + self.tp_mode = tp_mode + self.all_reduce = all_reduce + self.tp_group = dist_ctx.moe_tp_group.gpu_group + def update_weights(self): """Update weights.""" gate_up_weights, down_weights = self.impl.update_weights(self.gate_up.weight, self.down.weight) @@ -283,8 +275,10 @@ def update_weights(self): self.down.update_weight(down_weights) def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.LongTensor): - hidden_states, topk_weights, topk_ids = _moe_gather_inputs(hidden_states, topk_weights, topk_ids, - self.enable_ep) + hidden_states, topk_weights, topk_ids = _moe_gather_inputs(hidden_states, + topk_weights, + topk_ids, + group=self.tp_group) ret = self.impl.forward(hidden_states, topk_weights, @@ -296,7 +290,7 @@ def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ self.expert_list, act_func=self.act_func) if self.all_reduce: - ret = _moe_reduce(ret, self.enable_ep) + ret = _moe_reduce(ret, rank=self.tp_rank, tp_mode=self.tp_mode, group=self.tp_group) return ret @@ -544,6 +538,7 @@ def __init__(self, device = torch.device('cpu') dtype = torch.float16 if dtype is None else dtype self.block_size = 128 + self.init_tp_args(all_reduce, enable_ep) dist_ctx = get_dist_manager().current_context() self.ep_size, rank = get_ep_world_rank() impl_builder = get_backend().get_layer_impl_builder(OpType.FusedMoEBlockedF8) @@ -594,12 +589,25 @@ def __init__(self, self.num_experts = num_experts self.dtype = dtype self.device = device - world_size, _ = get_tp_world_rank('moe') - if world_size == 1: - all_reduce = False - self.all_reduce = all_reduce self.act_func = act_func 
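As context for the init_tp_args helpers added in this patch, the DP_TP bookkeeping they feed (dp_meta.moe_tp_sizes driving gather_by_tp_sizes before the experts and reduce_scatter_by_tp_sizes after) can be sketched without any process groups. This is a single-process illustration of the intended shape semantics only; the tensors and the example tp_sizes are made up, and the cross-rank reduction is elided.

import torch

# per-DP-rank token counts, as carried by dp_meta.moe_tp_sizes
tp_sizes = [3, 5]
rank = 1          # position of this rank inside the gather group
hidden = 16       # hypothetical hidden size

# gather_by_tp_sizes: every rank contributes tp_sizes[r] rows, results are stacked
chunks = [torch.randn(n, hidden) for n in tp_sizes]   # chunks[rank] is the local input
gathered = torch.cat(chunks, dim=-2)
assert gathered.shape[-2] == sum(tp_sizes)

# reduce_scatter_by_tp_sizes: after the fused experts produce sum(tp_sizes) rows,
# each rank keeps only its own slice (the sum across TP peers is omitted here)
kept = gathered.split(tp_sizes, dim=-2)[rank]
assert kept.shape[-2] == tp_sizes[rank]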
+ def init_tp_args(self, all_reduce: bool, enable_ep: bool): + """Init tp args.""" + tp, tp_rank = get_tp_world_rank('moe') + dist_ctx = get_dist_manager().current_context() + dist_cfg = dist_ctx.dist_config + _, tp_mode = dist_cfg.get_tp_by_layer('moe') + tp = 1 if enable_ep else tp + tp_rank = 0 if enable_ep else tp_rank + all_reduce = all_reduce if tp > 1 else False + all_reduce = False if enable_ep else all_reduce + + self.tp = tp + self.tp_rank = tp_rank + self.tp_mode = tp_mode + self.all_reduce = all_reduce + self.tp_group = dist_ctx.moe_tp_group.gpu_group + def update_weights(self): """Update weights.""" (gate_up_weights, down_weights, gate_up_scale, @@ -682,8 +690,10 @@ def dispatch(self, state: Dict): else: recv_state['hook'] = hook else: # MoeType.Default - hidden_states, topk_weights, topk_idx = _moe_gather_inputs(state['hidden_states'], state['topk_weights'], - state['topk_idx'], False) + hidden_states, topk_weights, topk_idx = _moe_gather_inputs(state['hidden_states'], + state['topk_weights'], + state['topk_idx'], + group=self.tp_group) recv_state = { 'hidden_states': hidden_states, 'topk_idx': topk_idx, @@ -769,7 +779,10 @@ def combine(self, state: Dict): out_state['hook'] = hook else: # MoeType.Default if self.all_reduce: - state['hidden_states'] = _moe_reduce(state['hidden_states'], False) + state['hidden_states'] = _moe_reduce(state['hidden_states'], + rank=self.tp_rank, + tp_mode=self.tp_mode, + group=self.tp_group) out_state = {'hidden_states': state['hidden_states'], 'moe_type': state['moe_type']} return out_state diff --git a/lmdeploy/serve/openai/launch_server.py b/lmdeploy/serve/openai/launch_server.py index 94121b6a94..2d2fd56c3f 100644 --- a/lmdeploy/serve/openai/launch_server.py +++ b/lmdeploy/serve/openai/launch_server.py @@ -97,11 +97,10 @@ def launch_server(num_nodes: int, ep = backend_config.ep assert dp > 1, f'only support dp > 1, but give dp={dp}' assert tp > 1 or ep > 1, f'only support tp > 1 or ep > 1, but given tp={tp} ep={ep}' - assert tp <= dp, f'only support tp <= dp, but give tp={tp}, dp={dp}' - assert ep <= dp, f'only support ep <= dp, but give tp={ep}, dp={dp}' + num_devices = max(dp, tp, ep) dp_per_node = dp // num_nodes - tp_per_dp = 1 # each dp uses one rank + tp_per_dp = num_devices // dp http_or_https = 'https' if kwargs.get('ssl', False) else 'http' model_name = kwargs.get('model_name', None) if model_name is None: From 7552b6fd3a0268116178c6c93ae59e0c6a30b386 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 23 Sep 2025 21:04:00 +0800 Subject: [PATCH 03/15] refactor --- lmdeploy/pytorch/engine/model_agent.py | 180 ++++++++++------------ lmdeploy/pytorch/nn/linear/blocked_fp8.py | 3 +- 2 files changed, 85 insertions(+), 98 deletions(-) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 1ac733d6d6..695e9c1b95 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -3,7 +3,7 @@ import base64 import functools import time -from contextlib import asynccontextmanager, contextmanager +from contextlib import contextmanager from dataclasses import dataclass, fields from multiprocessing.reduction import ForkingPickler from os import getenv @@ -347,6 +347,7 @@ def __init__(self, world_size = self.dist_config.world_size self.tp = tp self.world_size = world_size + self.need_output = rank % self.dist_config.attn_tp == 0 self.patched_model = None self.cache_engine = None @@ -561,18 +562,69 @@ def _push_output(self, output: BatchedOutputs): 
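The launcher and output-ownership arithmetic touched above is small enough to sanity-check offline. The sketch below only re-derives it with plain Python: num_devices = max(dp, tp, ep) and tp_per_dp = num_devices // dp mirror the updated launch_server, while output_ranks lists the ranks for which need_output is true (the leader of each attention-TP group). The concrete numbers are made up for illustration.

def layout(num_nodes: int, dp: int, tp: int, ep: int, attn_tp: int):
    num_devices = max(dp, tp, ep)      # devices spanned by one engine instance
    dp_per_node = dp // num_nodes      # DP ranks hosted on each node
    tp_per_dp = num_devices // dp      # ranks serving a single DP group
    # only the first rank of every attention-TP group pushes outputs
    output_ranks = [r for r in range(num_devices) if r % attn_tp == 0]
    return num_devices, dp_per_node, tp_per_dp, output_ranks

print(layout(num_nodes=1, dp=4, tp=8, ep=1, attn_tp=2))  # (8, 4, 2, [0, 2, 4, 6])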
self._out_que.put_nowait((output, event)) @contextmanager - def _broadcast_next_token(self, next_token_ids: torch.Tensor, dist_ctx: DistContext, enable: bool = True): + def _broadcast_next_token(self, next_token_ids: torch.Tensor, enable: bool = True): if not enable: yield return - dist_ctx = get_dist_manager().current_context() - tp_gpu_group = dist_ctx.attn_tp_group.gpu_group + tp_gpu_group = self.dist_ctx.attn_tp_group.gpu_group rank = dist.get_global_rank(tp_gpu_group, 0) handle = dist.broadcast(next_token_ids, src=rank, group=tp_gpu_group, async_op=True) yield handle.wait() + @record_function('prepare_dp') + async def _prepare_dp(self, inputs: ModelInputs, sync_long_context: bool, is_dummy: bool): + """Prepare dp.""" + world_size = self.dist_config.world_size + is_decoding = inputs.is_decoding + + # gather dp forward metadata + batch_size = inputs.seq_length.numel() + dp_forward_meta = [int(is_decoding), int(is_dummy), batch_size, int(sync_long_context)] + # check enable_microbatch + if self.enable_microbatch: + tokens_num = inputs.input_ids.numel() + if is_decoding: + enable_microbatch = batch_size >= \ + self.enable_microbatch_decode_batchsize_threshold + else: + enable_microbatch = batch_size >= \ + self.enable_microbatch_prefill_batchsize_threshold and \ + tokens_num >= self.enable_microbatch_prefill_token_threshold + dp_forward_meta.append(int(enable_microbatch)) + gathered_meta = DistGatherScalar(dp_forward_meta, world_size, device='cuda') + gathered_meta = (await gathered_meta.async_wait()).cpu() + + # check is_decoding + all_is_decoding = gathered_meta[:, 0] + assert all_is_decoding.sum().item() in [0, world_size] + + # check if all inputs are dummy inputs + is_all_dummy = gathered_meta[:, 1].all() + if is_all_dummy: + return inputs, sync_long_context, is_all_dummy + + if is_decoding: + all_batch_sizes = gathered_meta[:, 2] + padding_batch_size = all_batch_sizes.max().item() + meta = self.patched_model.get_meta() + meta.padding_batch_size = padding_batch_size + logger.debug(f'padding_batch_size={padding_batch_size}') + else: + all_sync_flags = gathered_meta[:, 3].bool() + sync_long_context = all_sync_flags.any() + logger.debug(f'sync_long_context={sync_long_context}') + + # update if enable_microbatch + if self.enable_microbatch and gathered_meta[:, 4].all(): + inputs.enable_microbatch = True + + # update dp meta + inputs.build_dp_meta() + inputs = self.patched_model.update_inputs(inputs) + return inputs, sync_long_context, is_all_dummy + async def _async_step_background( self, inputs: ModelInputs, @@ -587,7 +639,6 @@ async def _async_step_background( extra_inputs: ExtraInputs = None, ): """Asyc forward task.""" - dist_ctx = get_dist_manager().current_context() @record_function('update_inputs_for_next_step') def __update_inputs(next_token_ids, model_metas, extra_inputs): @@ -600,90 +651,29 @@ def __update_inputs(next_token_ids, model_metas, extra_inputs): extra_inputs=extra_inputs, ) - @asynccontextmanager - async def __prepare_dp(): - """Prepare dp.""" - if dp == 1: - yield - return - - nonlocal inputs, sync_long_context, is_all_dummy - world_size = self.dist_config.world_size - - # gather dp forward metadata - batch_size = inputs.seq_length.numel() - dp_forward_meta = [int(is_decoding), int(is_dummy), batch_size, int(sync_long_context)] - # check enable_microbatch - if self.enable_microbatch: - tokens_num = inputs.input_ids.numel() - if is_decoding: - enable_microbatch = batch_size >= \ - self.enable_microbatch_decode_batchsize_threshold - else: - enable_microbatch = 
batch_size >= \ - self.enable_microbatch_prefill_batchsize_threshold and \ - tokens_num >= self.enable_microbatch_prefill_token_threshold - dp_forward_meta.append(int(enable_microbatch)) - gathered_meta = DistGatherScalar(dp_forward_meta, world_size, device='cuda') - - yield - - gathered_meta = (await gathered_meta.async_wait()).cpu() - - # check is_decoding - all_is_decoding = gathered_meta[:, 0] - assert all_is_decoding.sum().item() in [0, world_size] - - # check if all inputs are dummy inputs - is_all_dummy = gathered_meta[:, 1].all() - if is_all_dummy: - return - - if is_decoding: - all_batch_sizes = gathered_meta[:, 2] - padding_batch_size = all_batch_sizes.max().item() - meta = self.patched_model.get_meta() - meta.padding_batch_size = padding_batch_size - logger.debug(f'padding_batch_size={padding_batch_size}') - else: - all_sync_flags = gathered_meta[:, 3].bool() - sync_long_context = all_sync_flags.any() - logger.debug(f'sync_long_context={sync_long_context}') - - # update if enable_microbatch - if self.enable_microbatch and gathered_meta[:, 4].all(): - inputs.enable_microbatch = True - - # update dp meta - inputs.build_dp_meta() - inputs = self.patched_model.update_inputs(inputs) - # dist tools dist_ctx = get_dist_manager().current_context() dist_config = dist_ctx.dist_config - rank = dist_ctx.rank + rank = self.rank tp = dist_config.attn_tp dp = dist_config.dp sync_long_context = False if dp == 1 else sync_long_context is_decoding = inputs.is_decoding + # is_all_dummy would be updated in __prepare_dp + if dp > 1: + inputs, sync_long_context, is_all_dummy = await self._prepare_dp(inputs, sync_long_context, is_dummy) + + # skip dummy forward. + if is_all_dummy: + logger.debug(f' rank[{rank}]: all inputs are dummy, skip forward.') + return + logger.debug(f' rank[{rank}]: ' f'batch_size={inputs.seq_length.size(0)} ' f'num_tokens={inputs.input_ids.size(-1)} ' f'is_decoding={inputs.is_decoding}') - # is_all_dummy would be updated in __prepare_dp - is_all_dummy = False - async with __prepare_dp(): - pass - - need_output = dp > 1 or rank % tp == 0 - - # skip dummy forward. 
- if is_all_dummy: - logger.debug(f' rank[{rank}]: all inputs are dummy, skip forward.') - return - cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map) for idx in range(loop_count): # inference @@ -693,12 +683,11 @@ async def __prepare_dp(): return_logits=return_logits, sync_long_context=sync_long_context, ) - logits = output['logits'] - logits = logits[0] # [bs, seq, prob] -> [seq, prob] - seq_length = inputs.seq_length + logits = output['logits'][0] # [bs, seq, prob] -> [seq, prob] seq_length = output.get('seq_length', inputs.seq_length) last_logits = self._slice_outs(logits, seq_length) # [bs, 1, prob] -> [bs, prob] extra_inputs = self.agent_strategy.slice_extra_inputs(extra_inputs, seq_length) + model_metas = output.get('model_metas') # output empty for dummy inputs if is_dummy: @@ -707,12 +696,12 @@ async def __prepare_dp(): need_broadcast_next = (tp > 1 and idx < loop_count - 1) # sampling and stopping - if need_output: + if self.need_output: logger.debug(f' rank[{rank}]: Sampling [{idx}].') # sampling next_token_ids, logprobs = await self.async_sampling_logits(last_logits, sampling_inputs, inputs) - with self._broadcast_next_token(next_token_ids, dist_ctx, enable=need_broadcast_next): + with self._broadcast_next_token(next_token_ids, enable=need_broadcast_next): logger.debug(f' rank[{rank}]: synchronize token ids [{idx}]') # post sampling @@ -724,35 +713,32 @@ async def __prepare_dp(): sampling_inputs.stop_words, inputs=inputs, extra_inputs=extra_inputs) + + # send output + logger.debug(f' rank[{rank}]: Output [{idx}]') + extra_outputs = self.agent_strategy.make_extra_outputs(extra_inputs) + self._push_output( + BatchedOutputs(next_token_ids=next_token_ids, + logits=logits if return_logits else None, + stopped=stopped, + stop_pos=stop_pos, + model_metas=model_metas, + logprobs=logprobs, + extra_outputs=extra_outputs)) else: # Avoid adding the ADInplaceOrView dispatch key to `next_token_ids`, # as it can trigger recompilation on different ranks when using torch.compile. 
with torch.inference_mode(): next_token_ids = inputs.input_ids.new_zeros(last_logits.size(0)) - logprobs = None # broadcast next token for TP > 1 - with self._broadcast_next_token(next_token_ids, dist_ctx, enable=need_broadcast_next): + with self._broadcast_next_token(next_token_ids, enable=need_broadcast_next): logger.debug(f' rank[{rank}]: synchronize token ids [{idx}]') # post sampling next_token_ids, extra_inputs = self.agent_strategy.post_sampling(inputs, last_logits, next_token_ids, extra_inputs) - # send output - model_metas = output.get('model_metas') - if need_output: - logger.debug(f' rank[{rank}]: Output [{idx}]') - extra_outputs = self.agent_strategy.make_extra_outputs(extra_inputs) - self._push_output( - BatchedOutputs(next_token_ids=next_token_ids, - logits=logits if return_logits else None, - stopped=stopped, - stop_pos=stop_pos, - model_metas=model_metas, - logprobs=logprobs, - extra_outputs=extra_outputs)) - # update for next loop if is_decoding and idx < loop_count - 1: inputs, extra_inputs = __update_inputs(next_token_ids, model_metas, extra_inputs) diff --git a/lmdeploy/pytorch/nn/linear/blocked_fp8.py b/lmdeploy/pytorch/nn/linear/blocked_fp8.py index 163d4ee1f4..af67358bc3 100644 --- a/lmdeploy/pytorch/nn/linear/blocked_fp8.py +++ b/lmdeploy/pytorch/nn/linear/blocked_fp8.py @@ -4,6 +4,7 @@ import torch from lmdeploy.pytorch.backends import OpType, get_backend +from lmdeploy.pytorch.config import TPMode from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader from ..quant_utils import quant_blocked_fp8 @@ -141,7 +142,7 @@ def update_weights(self): def _forward_default(self, x, all_reduce, tp_sizes): """Default forward implement.""" - if self.tp_mode: + if self.tp_mode == TPMode.DP_TP: rank = self.tp_rank return self.impl.forward(x, self.weight, From 1068fa4e0914ae9c9375a4a646d39b7e582f1d8a Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 24 Sep 2025 12:33:10 +0800 Subject: [PATCH 04/15] fix --- lmdeploy/pytorch/engine/model_agent.py | 41 +++++++++++++++++++------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 695e9c1b95..141af961a6 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -1083,6 +1083,8 @@ def __init__(self, model_agent: BaseModelAgent): self._next_inputs = None self._is_decoding = False self._ready_event = torch.cuda.Event() + self._attn_tp = self.dist_ctx.dist_config.attn_tp + self._attn_tp_cpu_group = self.dist_ctx.attn_tp_group.cpu_group def _make_dummy_forward_inputs(self): """Make dummy forward inputs.""" @@ -1110,6 +1112,34 @@ def _update_is_decoding(self, forward_inputs): if self.cache_config.role != EngineRole.Prefill: self._is_decoding = not self._is_decoding + async def _try_get_inputs(self): + # try get inputs + need_dummy = True + forward_inputs = None + try: + forward_inputs = await asyncio.wait_for(self._in_que.get(), timeout=0.02) + model_inputs = forward_inputs['inputs'] + if model_inputs.is_decoding != self._is_decoding: + self._next_inputs = forward_inputs + else: + need_dummy = False + except asyncio.TimeoutError: + pass + + # sync between attn tp + if self._attn_tp > 1: + all_need_dummy = torch.zeros((self._attn_tp, ), dtype=torch.bool) + dist.all_gather_into_tensor(all_need_dummy, + torch.tensor((need_dummy, )), + group=self._attn_tp_cpu_group, + async_op=False) + has_real = not (all_need_dummy.all().item()) + if has_real and need_dummy: + need_dummy = False 
+ forward_inputs = await self._in_que.get() + + return forward_inputs, need_dummy + async def get(self): """get.""" if self._next_inputs is not None: @@ -1123,16 +1153,7 @@ async def get(self): await asyncio.sleep(0.001) # try get inputs - need_dummy = True - try: - forward_inputs = await asyncio.wait_for(self._in_que.get(), timeout=0.02) - model_inputs = forward_inputs['inputs'] - if model_inputs.is_decoding != self._is_decoding: - self._next_inputs = forward_inputs - else: - need_dummy = False - except asyncio.TimeoutError: - pass + forward_inputs, need_dummy = await self._try_get_inputs() # make dummy inputs if need_dummy: From 6e300e6c9b34ddaac90ab0b140c255a24f3c28d5 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 24 Sep 2025 15:08:05 +0800 Subject: [PATCH 05/15] fix --- lmdeploy/pytorch/engine/model_agent.py | 73 +++++++++++++++++++------- 1 file changed, 54 insertions(+), 19 deletions(-) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 141af961a6..9cb663670f 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -1083,7 +1083,6 @@ def __init__(self, model_agent: BaseModelAgent): self._next_inputs = None self._is_decoding = False self._ready_event = torch.cuda.Event() - self._attn_tp = self.dist_ctx.dist_config.attn_tp self._attn_tp_cpu_group = self.dist_ctx.attn_tp_group.cpu_group def _make_dummy_forward_inputs(self): @@ -1112,31 +1111,67 @@ def _update_is_decoding(self, forward_inputs): if self.cache_config.role != EngineRole.Prefill: self._is_decoding = not self._is_decoding - async def _try_get_inputs(self): - # try get inputs - need_dummy = True - forward_inputs = None + async def _broadcast_has_inputs(self, has_inputs: bool = False): + """Broadcast has inputs.""" + attn_tp_group = self.dist_ctx.attn_tp_group + attn_tp = self.dist_ctx.dist_config.attn_tp + if attn_tp == 1: + return has_inputs + + group = attn_tp_group.cpu_group + rank = dist.get_global_rank(group, 0) + has_inputs = torch.tensor((has_inputs, )) + handle = dist.broadcast(has_inputs, src=rank, group=group, async_op=True) + future = handle.get_future() + while not future.done(): + await asyncio.sleep(0) + return has_inputs.item() + + async def _get_inputs_rank0(self): + """Try get inputs rank0.""" try: forward_inputs = await asyncio.wait_for(self._in_que.get(), timeout=0.02) + except asyncio.TimeoutError: + forward_inputs = None + + has_inputs = forward_inputs is not None + await self._broadcast_has_inputs(has_inputs) + return forward_inputs + + async def _get_inputs_rankn(self): + """Try get inputs rankn.""" + # broadcast + has_inputs = await self._broadcast_has_inputs() + + # try get inputs + if has_inputs: + forward_inputs = await self._in_que.get() + else: + forward_inputs = None + return forward_inputs + + async def _try_get_inputs(self): + """Try get inputs.""" + + attn_tp_group = self.dist_ctx.attn_tp_group + tp_rank = attn_tp_group.rank + + # initialize output + forward_inputs = None + need_dummy = True + + # get inputs from in_que. Rank 1 will not gather if rank 0 does not read inputs. 
+ if tp_rank == 0: + forward_inputs = await self._get_inputs_rank0() + else: + forward_inputs = await self._get_inputs_rankn() + + if forward_inputs is not None: model_inputs = forward_inputs['inputs'] if model_inputs.is_decoding != self._is_decoding: self._next_inputs = forward_inputs else: need_dummy = False - except asyncio.TimeoutError: - pass - - # sync between attn tp - if self._attn_tp > 1: - all_need_dummy = torch.zeros((self._attn_tp, ), dtype=torch.bool) - dist.all_gather_into_tensor(all_need_dummy, - torch.tensor((need_dummy, )), - group=self._attn_tp_cpu_group, - async_op=False) - has_real = not (all_need_dummy.all().item()) - if has_real and need_dummy: - need_dummy = False - forward_inputs = await self._in_que.get() return forward_inputs, need_dummy From 6a08f868f2498f0691652c4f7c7ac19d6c293503 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 24 Sep 2025 15:53:13 +0800 Subject: [PATCH 06/15] vis --- lmdeploy/pytorch/tools/utils.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/lmdeploy/pytorch/tools/utils.py b/lmdeploy/pytorch/tools/utils.py index 98d42772c0..899dfc42fa 100644 --- a/lmdeploy/pytorch/tools/utils.py +++ b/lmdeploy/pytorch/tools/utils.py @@ -192,3 +192,24 @@ def _print_meta(out: Response): print(colored('─' * (term_size), border_color, attrs=['dark'])) else: print(colored('━' * term_size, border_color)) + + +def visualize_chat_completions(outputs, enable_meta: bool = True): + """Visualize chat completions.""" + from openai.types.chat import ChatCompletion + + from lmdeploy.messages import Response + if isinstance(outputs, ChatCompletion): + outputs = [outputs] + + resps = [] + for out in outputs: + assert isinstance(out, ChatCompletion) + choice = out.choices[0] + resp = Response(text=choice.message.content, + input_token_len=out.usage.prompt_tokens, + generate_token_len=out.usage.completion_tokens, + finish_reason=choice.finish_reason) + resps.append(resp) + + return visualize_pipe_out(resps, enable_meta=enable_meta) From 6d2e839c8f5c72281a0c47397547117255aad1b8 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 24 Sep 2025 20:54:43 +0800 Subject: [PATCH 07/15] fix pd --- lmdeploy/pytorch/engine/engine.py | 4 +++- lmdeploy/serve/proxy/proxy.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 2179b5a99f..d3c963370f 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -602,7 +602,9 @@ def __update_max_new_tokens(msg): sampling_param = msg.sampling_param max_new_tokens = sampling_param.max_new_tokens num_all_tokens = msg.num_valid_ids - if max_new_tokens + num_all_tokens > max_session_len: + if self.engine_config.role == EngineRole.Prefill: + sampling_param.max_new_tokens = 1 + elif max_new_tokens + num_all_tokens > max_session_len: logger.warning( f'session[{msg.session_id}]: num tokens is larger than max session len {max_session_len}. 
' f'Update max_new_tokens={max_session_len - num_all_tokens}.') diff --git a/lmdeploy/serve/proxy/proxy.py b/lmdeploy/serve/proxy/proxy.py index 4f88340593..11f0caec47 100644 --- a/lmdeploy/serve/proxy/proxy.py +++ b/lmdeploy/serve/proxy/proxy.py @@ -597,6 +597,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque # Prefill prefill_request_dict = copy.deepcopy(request_dict) prefill_request_dict['max_tokens'] = 1 + prefill_request_dict['max_completion_tokens'] = 1 prefill_request_dict['stream'] = False prefill_request_dict['with_cache'] = True prefill_request_dict['preserve_cache'] = True From cb0e87587e7e2e5fe5dc04f24afe658df27c4871 Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 25 Sep 2025 11:22:08 +0800 Subject: [PATCH 08/15] optimize gather --- lmdeploy/pytorch/distributed.py | 49 ++++++++++++++++++++++++------ lmdeploy/pytorch/model_inputs.py | 18 ++++++----- lmdeploy/pytorch/nn/linear/base.py | 9 ++++-- lmdeploy/pytorch/nn/moe.py | 6 ++-- 4 files changed, 60 insertions(+), 22 deletions(-) diff --git a/lmdeploy/pytorch/distributed.py b/lmdeploy/pytorch/distributed.py index 05f9391c3d..1269e6a36b 100644 --- a/lmdeploy/pytorch/distributed.py +++ b/lmdeploy/pytorch/distributed.py @@ -9,7 +9,7 @@ from torch import distributed as dist from torch.distributed import ProcessGroup, ReduceOp # noqa: F401 -from .config import DistConfig +from .config import DistConfig, TPMode @dataclass @@ -20,6 +20,7 @@ class DistGroup: gpu_group: dist.ProcessGroup = None cpu_groups: List[dist.ProcessGroup] = None gpu_groups: List[dist.ProcessGroup] = None + gpu_gather_group: dist.ProcessGroup = None def close(self): """Close groups.""" @@ -40,29 +41,48 @@ def _build_tp_group_impl(tp: int, world_size: int, timeout: timedelta, cpu_backend: str = 'gloo', - ccl_backend: str = 'nccl'): + ccl_backend: str = 'nccl', + attn_tp: int = 1, + tp_mode: TPMode = TPMode.DEFAULT): """Build tp group.""" assert tp > 1 tp_rank = rank % tp tp_group_id = rank // tp + gather_group_id = (rank - tp_group_id * tp) % attn_tp ranks = range(world_size) tp_gpu_groups = [] tp_cpu_groups = [] + gather_groups = [] for start in range(0, world_size, tp): tp_ranks = ranks[start:start + tp] group = dist.new_group(ranks=tp_ranks, timeout=timeout, backend=ccl_backend) tp_gpu_groups.append(group) cpu_group = dist.new_group(ranks=tp_ranks, timeout=timeout, backend=cpu_backend) tp_cpu_groups.append(cpu_group) + + # create gather group + if tp_mode == TPMode.DP_TP and attn_tp != tp: + for g_start in range(start, start + attn_tp): + g_ranks = ranks[g_start:(g_start + tp):attn_tp] + gather_group = dist.new_group(ranks=g_ranks, timeout=timeout, backend=ccl_backend) + gather_groups.append(gather_group) tp_gpu_group = tp_gpu_groups[tp_group_id] tp_cpu_group = tp_cpu_groups[tp_group_id] + if tp_mode == TPMode.DP_TP: + if attn_tp == tp: + gather_group = tp_gpu_group + else: + gather_group = gather_groups[gather_group_id] + else: + gather_group = None return DistGroup( rank=tp_rank, cpu_group=tp_cpu_group, gpu_group=tp_gpu_group, cpu_groups=tp_cpu_groups, gpu_groups=tp_gpu_groups, + gpu_gather_group=gather_group, ) @@ -85,6 +105,8 @@ def _build_attn_tp_group(context: 'DistContext', timeout=timeout, cpu_backend=cpu_backend, ccl_backend=ccl_backend, + attn_tp=tp, + tp_mode=TPMode.DEFAULT, ) context.attn_tp_group = dist_group @@ -113,6 +135,8 @@ def _build_mlp_tp_group(context: 'DistContext', timeout=timeout, cpu_backend=cpu_backend, ccl_backend=ccl_backend, + attn_tp=dist_config.attn_tp, + tp_mode=dist_config.mlp_tp_mode, ) 
context.mlp_tp_group = dist_group @@ -146,6 +170,8 @@ def _build_moe_tp_group(context: 'DistContext', timeout=timeout, cpu_backend=cpu_backend, ccl_backend=ccl_backend, + attn_tp=dist_config.attn_tp, + tp_mode=dist_config.moe_tp_mode, ) context.moe_tp_group = dist_group @@ -330,12 +356,9 @@ def get_process_group(device: str = None): return dist.GroupMember.WORLD -def get_tp_group(device: str = 'gpu', layer_type: str = 'attn'): - """Get tp group.""" +def get_dist_group(layer_type: str = 'attn'): + """Get dist group.""" ctx = get_dist_manager().current_context() - - _check_group_device(device) - if layer_type == 'attn': tp_group = ctx.attn_tp_group elif layer_type == 'mlp': @@ -344,6 +367,13 @@ def get_tp_group(device: str = 'gpu', layer_type: str = 'attn'): tp_group = ctx.moe_tp_group else: raise RuntimeError(f'Unknown layer type: {layer_type}') + return tp_group + + +def get_tp_group(device: str = 'gpu', layer_type: str = 'attn'): + """Get tp group.""" + _check_group_device(device) + tp_group = get_dist_group(layer_type) if tp_group is None: return None @@ -414,8 +444,9 @@ def gather_by_tp_sizes(x: torch.Tensor, tp_sizes: List[int], group: Optional[dis def reduce_scatter_by_tp_sizes(out: torch.Tensor, rank: int, tp_sizes: List[int], group: dist.ProcessGroup): """Reduce scatter.""" - outs = out.split(tp_sizes, -2) + attn_tp = get_dist_manager().current_config().attn_tp + outs = list(out.split(tp_sizes, -2)) + outs = [item for item in outs for _ in range(attn_tp)] out = outs[rank] - outs = list(outs) dist.reduce_scatter(out, outs, group=group) return out diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index c8d8aa0c1c..6b9ba84dc6 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -22,12 +22,14 @@ class DPMeta: moe_tp_sizes: List[int] = None @staticmethod - def _gather_tp_sizes(tp: int, attn_tp: int, seqlen: int, layer_type: str): + def _gather_tp_sizes(tp: int, seqlen: int, dist_ctx: dist.DistContext, layer_type: str): """Gather tp size.""" + attn_tp = dist_ctx.dist_config.attn_tp if tp > 1 and tp != attn_tp: - tp_sizes = [None for _ in range(tp)] - tp_group = dist.get_tp_group('gpu', layer_type=layer_type) - dist.all_gather_object(tp_sizes, seqlen, group=tp_group) + dist_group = dist.get_dist_group(layer_type=layer_type) + gather_group = dist_group.gpu_gather_group + tp_sizes = [None for _ in range(gather_group.size())] + dist.all_gather_object(tp_sizes, seqlen, group=gather_group) else: tp_sizes = [seqlen] return tp_sizes @@ -35,17 +37,17 @@ def _gather_tp_sizes(tp: int, attn_tp: int, seqlen: int, layer_type: str): @classmethod def build(cls, seqlen: int): """Get dp meta.""" - dist_config = dist.get_dist_manager().current_config() - attn_tp = dist_config.attn_tp + dist_ctx = dist.get_dist_manager().current_context() + dist_config = dist_ctx.dist_config mlp_tp = dist_config.mlp_tp - tp_sizes = cls._gather_tp_sizes(mlp_tp, attn_tp, seqlen, layer_type='mlp') + tp_sizes = cls._gather_tp_sizes(mlp_tp, seqlen, dist_ctx, layer_type='mlp') moe_tp = dist_config.moe_tp if moe_tp == mlp_tp: moe_tp_sizes = tp_sizes else: - moe_tp_sizes = cls._gather_tp_sizes(moe_tp, attn_tp, seqlen, layer_type='moe') + moe_tp_sizes = cls._gather_tp_sizes(moe_tp, seqlen, dist_ctx, layer_type='moe') return DPMeta(tp_sizes=tp_sizes, moe_tp_sizes=moe_tp_sizes) diff --git a/lmdeploy/pytorch/nn/linear/base.py b/lmdeploy/pytorch/nn/linear/base.py index 90057a920b..c69f8a580a 100644 --- a/lmdeploy/pytorch/nn/linear/base.py +++ 
b/lmdeploy/pytorch/nn/linear/base.py @@ -6,7 +6,7 @@ from torch import nn from lmdeploy.pytorch.config import TPMode -from lmdeploy.pytorch.distributed import (gather_by_tp_sizes, get_dist_manager, get_tp_group, get_tp_world_rank, +from lmdeploy.pytorch.distributed import (gather_by_tp_sizes, get_dist_group, get_dist_manager, get_tp_world_rank, reduce_scatter_by_tp_sizes) from lmdeploy.pytorch.model_inputs import get_step_ctx_manager @@ -55,12 +55,15 @@ def init_tp_args(self, is_tp: bool, all_reduce: bool, colwise: bool, layer_type: self.tp_rank = rank self.tp = tp self.tp_mode = tp_mode - self.tp_group = get_tp_group(layer_type=layer_type) + dist_group = get_dist_group(layer_type=layer_type) + self.tp_group = dist_group.gpu_group + self.gather_group = dist_group.gpu_gather_group else: self.tp_rank = 0 self.tp = 1 self.tp_mode = TPMode.DEFAULT self.tp_group = None + self.gather_group = None self._tp_args_initialized = True @@ -99,7 +102,7 @@ def forward(self, x): tp_sizes = dp_meta.tp_sizes if self.dp_gather: - x = gather_by_tp_sizes(x, tp_sizes, group=self.tp_group) + x = gather_by_tp_sizes(x, tp_sizes, group=self.gather_group) if len(self.lora_adapters) == 0: return self._forward_default(x, self.all_reduce, tp_sizes) diff --git a/lmdeploy/pytorch/nn/moe.py b/lmdeploy/pytorch/nn/moe.py index 13529db12c..ea69502665 100644 --- a/lmdeploy/pytorch/nn/moe.py +++ b/lmdeploy/pytorch/nn/moe.py @@ -267,6 +267,7 @@ def init_tp_args(self, all_reduce: bool, enable_ep: bool): self.tp_mode = tp_mode self.all_reduce = all_reduce self.tp_group = dist_ctx.moe_tp_group.gpu_group + self.gather_group = dist_ctx.moe_tp_group.gpu_gather_group def update_weights(self): """Update weights.""" @@ -278,7 +279,7 @@ def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ hidden_states, topk_weights, topk_ids = _moe_gather_inputs(hidden_states, topk_weights, topk_ids, - group=self.tp_group) + group=self.gather_group) ret = self.impl.forward(hidden_states, topk_weights, @@ -607,6 +608,7 @@ def init_tp_args(self, all_reduce: bool, enable_ep: bool): self.tp_mode = tp_mode self.all_reduce = all_reduce self.tp_group = dist_ctx.moe_tp_group.gpu_group + self.gather_group = dist_ctx.moe_tp_group.gpu_gather_group def update_weights(self): """Update weights.""" @@ -693,7 +695,7 @@ def dispatch(self, state: Dict): hidden_states, topk_weights, topk_idx = _moe_gather_inputs(state['hidden_states'], state['topk_weights'], state['topk_idx'], - group=self.tp_group) + group=self.gather_group) recv_state = { 'hidden_states': hidden_states, 'topk_idx': topk_idx, From ce0d460231542ed4eedac7d826e1b5001c09714a Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 25 Sep 2025 12:49:31 +0800 Subject: [PATCH 09/15] expose layer tp --- lmdeploy/messages.py | 6 ++++ lmdeploy/pytorch/config.py | 59 ++++++++++++++++++++++++------- lmdeploy/pytorch/engine/engine.py | 9 ++--- 3 files changed, 55 insertions(+), 19 deletions(-) diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 73c02e2914..4441e98c99 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -289,6 +289,9 @@ class PytorchEngineConfig: session_len (int): Max session length. Default None. max_batch_size (int): Max batch size. 
If it is not specified, the engine will automatically set it according to the device + attn_tp_size (int): tp size for attention, only works for dp>1 + mlp_tp_size (int): tp size for mlp, only works for dp>1 + moe_tp_size (int): tp size for moe, only works for dp>1 cache_max_entry_count (float): the percentage of gpu memory occupied by the k/v cache. For lmdeploy versions greater than `v0.2.1`, it defaults to 0.8, signifying the percentage of FREE GPU memory @@ -350,6 +353,9 @@ class PytorchEngineConfig: ep: int = 1 session_len: int = None max_batch_size: int = None + attn_tp_size: int = None + mlp_tp_size: int = None + moe_tp_size: int = None cache_max_entry_count: float = 0.8 prefill_interval: int = 16 block_size: int = 64 diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 4124783987..5fded92303 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -117,9 +117,9 @@ class DistConfig: # tp tp: int = 1 # default tp, equal to attn_tp - attn_tp: int = 1 # tp for attention - mlp_tp: int = 1 # tp for mlp - moe_tp: int = 1 # tp for moe + attn_tp: int = None # tp for attention + mlp_tp: int = None # tp for mlp + moe_tp: int = None # tp for moe # tp mode mlp_tp_mode: TPMode = TPMode.DEFAULT @@ -134,18 +134,37 @@ def __post_init__(self): tp = self.tp ep = self.ep + # ignore layer to for dp==1 + if dp == 1: + self.mlp_tp = None + self.attn_tp = None + self.moe_tp = None + + # mlp and moe tp + self.mlp_tp = self.mlp_tp or tp + self.moe_tp = self.moe_tp or (1 if ep > 1 else self.mlp_tp) + # world_size - world_size = ep if ep > 1 else tp - assert world_size >= dp and world_size % dp == 0 - assert world_size >= tp and world_size % tp == 0 - assert world_size >= ep and world_size % ep == 0 + world_size = ep if ep > 1 else max(self.mlp_tp, self.moe_tp) self.world_size = world_size - - # tp - self.attn_tp = self.world_size // dp - self.mlp_tp = tp - self.moe_tp = 1 if ep > 1 else tp + assert (world_size >= dp and world_size % dp == 0), (f'world_size {world_size}, dp {dp}') + assert (world_size >= ep and world_size % ep == 0), (f'world_size {world_size}, ep {ep}') + assert (world_size >= self.mlp_tp + and world_size % self.mlp_tp == 0), (f'world_size {world_size}, mlp_tp {self.mlp_tp}') + assert (world_size >= self.moe_tp + and world_size % self.moe_tp == 0), (f'world_size {world_size}, moe_tp {self.moe_tp}') + + # attn tp + self.attn_tp = self.attn_tp or self.world_size // dp self.tp = self.attn_tp + if self.mlp_tp > 1: + assert (self.mlp_tp >= self.attn_tp + and self.mlp_tp % self.attn_tp == 0), (f'mlp_tp {self.mlp_tp}, attn_tp {self.attn_tp}') + if self.moe_tp > 1: + assert (self.moe_tp >= self.attn_tp + and self.moe_tp % self.attn_tp == 0), (f'moe_tp {self.moe_tp}, attn_tp {self.attn_tp}') + assert (world_size >= self.attn_tp + and world_size % self.attn_tp == 0), (f'world_size {world_size}, attn_tp {self.attn_tp}') # tp mode self.mlp_tp_mode = TPMode.DEFAULT if (self.mlp_tp in [1, self.attn_tp]) else TPMode.DP_TP @@ -165,6 +184,22 @@ def get_tp_by_layer(self, layer_type: str): else: raise ValueError(f'Unknown layer type: {layer_type}') + @classmethod + def from_engine_config(cls, engine_config: PytorchEngineConfig): + """From engine config.""" + dist_config = cls( + dp=engine_config.dp, + ep=engine_config.ep, + dp_rank=engine_config.dp_rank, + enable_microbatch=engine_config.enable_microbatch, + enable_eplb=engine_config.enable_eplb, + tp=engine_config.tp, + attn_tp=engine_config.attn_tp_size, + mlp_tp=engine_config.mlp_tp_size, + 
moe_tp=engine_config.moe_tp_size, + ) + return dist_config + def _override_hf_config_dict(hf_config: dict, key: str, hf_overrides): """Override hf_config dict.""" diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index d3c963370f..3133b3827d 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -133,12 +133,7 @@ def _build_backend_config(engine_config: PytorchEngineConfig): def _build_dist_config(engine_config: PytorchEngineConfig): """Build dist config.""" - dist_config = DistConfig(dp=engine_config.dp, - tp=engine_config.tp, - ep=engine_config.ep, - dp_rank=engine_config.dp_rank, - enable_microbatch=engine_config.enable_microbatch, - enable_eplb=engine_config.enable_eplb) + dist_config = DistConfig.from_engine_config(engine_config=engine_config) return dist_config @@ -658,7 +653,7 @@ def model_config(self) -> ModelConfig: @property def gpu_count(self): - return self.tp * self.dp + return self.dist_config.world_size @property def torch_int_dtype(self): From 756d59a559e8fc3a4bb0ee85fcd110bc8f7a0996 Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 25 Sep 2025 21:30:09 +0800 Subject: [PATCH 10/15] ep + attn tp allgather --- lmdeploy/pytorch/backends/cuda/moe.py | 44 ++++++++++++++++++++++++++ lmdeploy/pytorch/engine/model_agent.py | 8 ++++- 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py index a1394f28bd..59831d8146 100644 --- a/lmdeploy/pytorch/backends/cuda/moe.py +++ b/lmdeploy/pytorch/backends/cuda/moe.py @@ -619,6 +619,45 @@ def ep_expert_list(self, world_size: int, rank: int): else: return super().ep_expert_list(world_size=world_size, rank=rank) + def _split_inputs_by_attn_tp( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.LongTensor, + ): + """Split input by attn tp.""" + dist_ctx = get_dist_manager().current_context() + attn_tp = dist_ctx.dist_config.attn_tp + attn_rank = dist_ctx.attn_tp_group.rank + num_states = hidden_states.size(0) + + if attn_tp == 1 or attn_tp > num_states: + return hidden_states, topk_weights, topk_ids, None + + # split size + base = num_states // attn_tp + remain = num_states % attn_tp + split_size = [base + 1] * remain + [base] * (attn_tp - remain) + + # split inputs + hidden_states = torch.split(hidden_states, split_size, dim=0)[attn_rank] + topk_weights = torch.split(topk_weights, split_size, dim=0)[attn_rank] + topk_ids = torch.split(topk_ids, split_size, dim=0)[attn_rank] + + return hidden_states, topk_weights, topk_ids, split_size + + def _gather_outputs_by_attn_tp(self, out_states: torch.Tensor, split_size: List[int]): + """Gather output by attn tp.""" + if split_size is None: + return out_states + + dist_ctx = get_dist_manager().current_context() + gpu_group = dist_ctx.attn_tp_group.gpu_group + new_out_states = out_states.new_empty((sum(split_size), out_states.shape[1])) + new_out_states_list = list(new_out_states.split(split_size, dim=0)) + dist.all_gather(new_out_states_list, out_states, group=gpu_group) + return new_out_states + def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, @@ -633,12 +672,17 @@ def forward(self, act_func: Callable = None, **kwargs): """forward.""" + hidden_states, topk_weights, topk_ids, split_size = self._split_inputs_by_attn_tp( + hidden_states, topk_weights, topk_ids) + topk_weights = self.do_renormalize(topk_weights) step_ctx = get_step_ctx_manager().current_context() low_latency_mode = 
step_ctx.is_decoding and self.use_deep_gemm moe = self.fusedmoe_build(low_latency_mode) out_states = moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale, down_weights, down_scale, expert_list) + + out_states = self._gather_outputs_by_attn_tp(out_states, split_size) return out_states def do_renormalize(self, topk_weights): diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 9cb663670f..66c829917f 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -392,7 +392,7 @@ def get_free_mem(self): def warmup(self): """warmup.""" # TODO: disable for now, do not remove the comments. - with self.all_context(): + with self.all_context(), torch.cuda.stream(self.stream): max_batches = self.cache_config.max_batches num_tokens = max_batches dp = self.dist_config.dp @@ -404,7 +404,10 @@ def warmup(self): vocab_size=self.model_config.vocab_size) if dp > 1: inputs.build_dp_meta() + logger.debug('Warmup prefill start.') self._forward_impl(inputs) + torch.cuda.synchronize() + logger.debug('Warmup prefill done.') # warmup decoding(with cuda graph) capture_batch_sizes = self.patched_model.get_capture_batch_sizes() @@ -416,7 +419,10 @@ def warmup(self): vocab_size=self.model_config.vocab_size) if dp > 1: inputs.build_dp_meta() + logger.debug(f'Warmup decoding num_tokens={num_tokens} start.') self._forward_impl(inputs) + torch.cuda.synchronize() + logger.debug(f'Warmup decoding num_tokens={num_tokens} done.') def _slice_outs(self, inputs: torch.Tensor, seq_length: torch.LongTensor): """Slice outputs.""" From d07c2f5886868726c5099bda0c123f85d1b1ca20 Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 26 Sep 2025 15:11:31 +0800 Subject: [PATCH 11/15] fix not aligned weight --- lmdeploy/pytorch/nn/moe.py | 63 +++++++++++++++++++++++++++++++++++--- 1 file changed, 58 insertions(+), 5 deletions(-) diff --git a/lmdeploy/pytorch/nn/moe.py b/lmdeploy/pytorch/nn/moe.py index ea69502665..11d7644df8 100644 --- a/lmdeploy/pytorch/nn/moe.py +++ b/lmdeploy/pytorch/nn/moe.py @@ -54,6 +54,16 @@ def _update_args(hidden_dim: int, ffn_dim: int): return hidden_dim, ffn_dim +def _split_size(size: int, world_size: int, align: int): + size = size // align + assert size >= world_size + base = size // world_size + remain = size % world_size + split_size = [base + 1] * remain + [base] * (world_size - remain) + split_size = [s * align for s in split_size] + return split_size + + class LinearWeights(nn.Module): """Fused moe linear weights.""" @@ -460,7 +470,7 @@ def __init__(self, device=device) weight_scale_inv = torch.nn.Parameter(weight_scale_inv, requires_grad=False) self.register_parameter('weight_scale_inv', weight_scale_inv) - self.weight._base_weight_loader = self.weight.weight_loader + self.weight._base_weight_loader = self.weight_loader_tp self.weight.weight_loader = self.weight_loader_with_quant if self.ep: @@ -485,6 +495,41 @@ def weight_loader_scale_ep(self, param: torch.nn.Parameter, loaded_weight: torch for expert_id in expert_ids: self.weight_loader_scale_tp(param, loaded_weight, expert_id, shard_id) + def _chunk_weight_tp(self, weight: torch.Tensor, dim: int, world_size: int, rank: int, align: int): + """Chunk with align.""" + split_size = _split_size(weight.size(dim), world_size, align) + return weight.split(split_size, dim=dim)[rank] + + def weight_loader_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, shard_id: str): + """Weight loader.""" + world_size, rank = 
get_tp_world_rank('moe') + if shard_id == 'gate': + param_data = param.data[expert_id, :self.half_out] + weight = self._chunk_weight_tp(loaded_weight, + dim=0, + world_size=world_size, + rank=rank, + align=self.block_size) + elif shard_id == 'up': + param_data = param.data[expert_id, self.half_out:] + weight = self._chunk_weight_tp(loaded_weight, + dim=0, + world_size=world_size, + rank=rank, + align=self.block_size) + elif shard_id == 'down': + param_data = param.data[expert_id] + # weight is not contiguous, chunk and copy in cpu is slow + weight = loaded_weight.to(param_data.device) + if weight.dim() > 1: + weight = self._chunk_weight_tp(weight, dim=1, world_size=world_size, rank=rank, align=self.block_size) + elif weight.dim() == 1 and rank != 0: + # bias with rank>0 should be 0 + weight = torch.zeros_like(weight) + else: + raise RuntimeError(f'Unknown shard_id: {shard_id}') + param_data.copy_(weight) + def weight_loader_scale_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, shard_id: str): """Weight loader scale tp.""" @@ -493,14 +538,14 @@ def weight_loader_scale_tp(self, param: torch.nn.Parameter, loaded_weight: torch half_out = self.half_out // block_size if shard_id == 'gate': param_data = param.data[expert_id, :half_out] - weight = loaded_weight.chunk(world_size, dim=0)[rank] + weight = self._chunk_weight_tp(loaded_weight, dim=0, world_size=world_size, rank=rank, align=1) elif shard_id == 'up': param_data = param.data[expert_id, half_out:] - weight = loaded_weight.chunk(world_size, dim=0)[rank] + weight = self._chunk_weight_tp(loaded_weight, dim=0, world_size=world_size, rank=rank, align=1) elif shard_id == 'down': param_data = param.data[expert_id] loaded_weight = loaded_weight.to(param_data.device) - weight = loaded_weight.chunk(world_size, dim=1)[rank] + weight = self._chunk_weight_tp(loaded_weight, dim=1, world_size=world_size, rank=rank, align=1) else: raise RuntimeError(f'Unknown shard_id: {shard_id}') param_data.copy_(weight) @@ -558,7 +603,7 @@ def __init__(self, expert_list = self.impl.ep_expert_list(self.ep_size, rank) num_experts = len(expert_list) else: - hidden_dim, ffn_dim = _update_args(hidden_dim, ffn_dim) + hidden_dim, ffn_dim = self._update_args(hidden_dim, ffn_dim, align=self.block_size) expert_list = None self.expert_list = expert_list @@ -592,6 +637,14 @@ def __init__(self, self.device = device self.act_func = act_func + @staticmethod + def _update_args(hidden_dim: int, ffn_dim: int, align: int): + """Update args.""" + world_size, rank = get_tp_world_rank('moe') + split_size = _split_size(ffn_dim, world_size, align) + ffn_dim = split_size[rank] + return hidden_dim, ffn_dim + def init_tp_args(self, all_reduce: bool, enable_ep: bool): """Init tp args.""" tp, tp_rank = get_tp_world_rank('moe') From 50138365690d08279bff0905ab83398f2ecdfff8 Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 26 Sep 2025 18:14:25 +0800 Subject: [PATCH 12/15] moe microbatch pipeline --- lmdeploy/pytorch/distributed.py | 11 ++- lmdeploy/pytorch/nn/moe.py | 156 +++++++++++++++++++++++++++++++- 2 files changed, 162 insertions(+), 5 deletions(-) diff --git a/lmdeploy/pytorch/distributed.py b/lmdeploy/pytorch/distributed.py index 1269e6a36b..8cf6a15a28 100644 --- a/lmdeploy/pytorch/distributed.py +++ b/lmdeploy/pytorch/distributed.py @@ -7,7 +7,7 @@ import torch from torch import distributed as dist -from torch.distributed import ProcessGroup, ReduceOp # noqa: F401 +from torch.distributed import ProcessGroup, ReduceOp, Work # noqa: F401 from .config import 
DistConfig, TPMode @@ -433,12 +433,17 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group='tp', async_op=Fal return dist.reduce_scatter(output, input_list, op=op, group=group, async_op=async_op) -def gather_by_tp_sizes(x: torch.Tensor, tp_sizes: List[int], group: Optional[dist.ProcessGroup] = None): +def gather_by_tp_sizes(x: torch.Tensor, + tp_sizes: List[int], + group: Optional[dist.ProcessGroup] = None, + async_op: bool = False): """Gather input.""" shape = (*x.shape[:-2], sum(tp_sizes), *x.shape[-1:]) new_x = x.new_empty(shape) split_new_x = list(new_x.split(tp_sizes, -2)) - dist.all_gather(split_new_x, x, group=group) + handle = dist.all_gather(split_new_x, x, group=group, async_op=async_op) + if async_op: + return new_x, handle return new_x diff --git a/lmdeploy/pytorch/nn/moe.py b/lmdeploy/pytorch/nn/moe.py index 11d7644df8..70451e5ee0 100644 --- a/lmdeploy/pytorch/nn/moe.py +++ b/lmdeploy/pytorch/nn/moe.py @@ -64,6 +64,103 @@ def _split_size(size: int, world_size: int, align: int): return split_size +class MoEForwardDPTP: + + def __init__(self, gemm_func: Callable, max_tokens_per_round: int = 8192): + """MoE forward dp tp.""" + self.gemm_func = gemm_func + self.dist_ctx = get_dist_manager().current_context() + self.dist_config = self.dist_ctx.dist_config + self.tp = self.dist_config.moe_tp + self.attn_tp = self.dist_config.attn_tp + + tp_group = self.dist_ctx.moe_tp_group + self.rank = tp_group.rank + self.gather_rank = self.rank // self.attn_tp + self.gather_group = tp_group.gpu_gather_group + self.tp_group = tp_group.gpu_group + self.max_tokens_per_round = max_tokens_per_round * self.attn_tp // self.tp // 2 + + def all_gather(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + tp_sizes: List[int]): + """All gather.""" + hidden_states, _ = dist.gather_by_tp_sizes(hidden_states, tp_sizes, group=self.gather_group, async_op=True) + topk_weights, _ = dist.gather_by_tp_sizes(topk_weights, tp_sizes, group=self.gather_group, async_op=True) + topk_ids, handle = dist.gather_by_tp_sizes(topk_ids, tp_sizes, group=self.gather_group, async_op=True) + return hidden_states, topk_weights, topk_ids, handle + + def reduce_scatter(self, hidden_states: torch.Tensor, out_states: torch.Tensor, tp_sizes: List[int]): + """Reduce scatter.""" + hidden_states_list = list(hidden_states.split(tp_sizes, -2)) + hidden_states_list[self.gather_rank] = out_states + hidden_states_list = [item for item in hidden_states_list for _ in range(self.attn_tp)] + handle = dist.reduce_scatter(out_states, hidden_states_list, group=self.tp_group, async_op=True) + return out_states, handle + + def _gemm_and_reduce_scatter(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + output_states: torch.Tensor, tp_sizes: List[int], handle: dist.Work): + """Gemm and reduce scatter.""" + handle.wait() + cur_out = self.gemm_func(hidden_states, topk_weights, topk_ids) + cur_out_states = cur_out.split(tp_sizes, dim=0)[self.gather_rank] + output_states.copy_(cur_out_states) + return self.reduce_scatter(cur_out, output_states, tp_sizes) + + def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor): + """forward.""" + + def __slice_tensor(tensor: torch.Tensor, slice_size: int): + """Slice tensor.""" + cur_tensor = tensor[:slice_size] + tensor = tensor[slice_size:] + return cur_tensor, tensor + + def __slice_and_gather(): + """Slice and gather.""" + nonlocal hidden_states, topk_weights, topk_ids, tp_sizes, output_states 
+ cur_tp_sizes = tp_sizes.minimum(max_tokens_per_round) + tp_sizes -= cur_tp_sizes + cur_tp_sizes = cur_tp_sizes.tolist() + + slice_size = cur_tp_sizes[self.gather_rank] + cur_hidden_states, hidden_states = __slice_tensor(hidden_states, slice_size) + cur_topk_weights, topk_weights = __slice_tensor(topk_weights, slice_size) + cur_topk_ids, topk_ids = __slice_tensor(topk_ids, slice_size) + cur_output, output_states = __slice_tensor(output_states, slice_size) + + # all gather + cur_hidden_states, cur_topk_weights, cur_topk_ids, handle = self.all_gather( + cur_hidden_states, cur_topk_weights, cur_topk_ids, cur_tp_sizes) + return dict(hidden_states=cur_hidden_states, + topk_weights=cur_topk_weights, + topk_ids=cur_topk_ids, + output_states=cur_output, + handle=handle, + tp_sizes=cur_tp_sizes) + + step_ctx = get_step_ctx_manager().current_context() + tp_sizes = step_ctx.dp_meta.moe_tp_sizes + tp_sizes = torch.tensor(tp_sizes) + max_tokens_per_round = tp_sizes.new_tensor(self.max_tokens_per_round) + + output_states = torch.empty_like(hidden_states) + return_states = output_states + + # pre + cur_inputs = __slice_and_gather() + + # main loop + while tp_sizes.sum() > 0: + next_inputs = __slice_and_gather() + self._gemm_and_reduce_scatter(**cur_inputs) + cur_inputs = next_inputs + + # post + _, handle = self._gemm_and_reduce_scatter(**cur_inputs) + handle.wait() + return return_states + + class LinearWeights(nn.Module): """Fused moe linear weights.""" @@ -279,13 +376,43 @@ def init_tp_args(self, all_reduce: bool, enable_ep: bool): self.tp_group = dist_ctx.moe_tp_group.gpu_group self.gather_group = dist_ctx.moe_tp_group.gpu_gather_group + if self.tp > 1 and self.tp_mode == TPMode.DP_TP: + + def __gemm_func(hidden_states, topk_weights, topk_ids): + return self.gemm( + dict( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_idx=topk_ids, + moe_type=MoeType.Default, + ))['hidden_states'] + + self.forward_dptp = MoEForwardDPTP(__gemm_func) + def update_weights(self): """Update weights.""" gate_up_weights, down_weights = self.impl.update_weights(self.gate_up.weight, self.down.weight) self.gate_up.update_weight(gate_up_weights) self.down.update_weight(down_weights) - def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.LongTensor): + def gemm(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """Gemm.""" + hidden_states = inputs['hidden_states'] + topk_weights = inputs['topk_weights'] + topk_ids = inputs['topk_idx'] + + ret = self.impl.forward(hidden_states, + topk_weights, + topk_ids, + self.gate_up.weight, + self.down.weight, + self.gate_up.bias, + self.down.bias, + self.expert_list, + act_func=self.act_func) + return dict(hidden_states=ret) + + def forward_default(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.LongTensor): hidden_states, topk_weights, topk_ids = _moe_gather_inputs(hidden_states, topk_weights, topk_ids, @@ -304,6 +431,12 @@ def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ ret = _moe_reduce(ret, rank=self.tp_rank, tp_mode=self.tp_mode, group=self.tp_group) return ret + def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.LongTensor): + """forward.""" + if self.tp > 1 and self.tp_mode == TPMode.DP_TP: + return self.forward_dptp.forward(hidden_states, topk_weights, topk_ids) + return self.forward_default(hidden_states, topk_weights, topk_ids) + class LinearWeightsW8A8(LinearWeights): """Fused moe linear w8a8 weights.""" @@ -662,6 
+795,18 @@ def init_tp_args(self, all_reduce: bool, enable_ep: bool): self.all_reduce = all_reduce self.tp_group = dist_ctx.moe_tp_group.gpu_group self.gather_group = dist_ctx.moe_tp_group.gpu_gather_group + if self.tp > 1 and self.tp_mode == TPMode.DP_TP: + + def __gemm_func(hidden_states, topk_weights, topk_ids): + return self.gemm( + dict( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_idx=topk_ids, + moe_type=MoeType.Default, + ))['hidden_states'] + + self.forward_dptp = MoEForwardDPTP(__gemm_func) def update_weights(self): """Update weights.""" @@ -671,7 +816,7 @@ def update_weights(self): self.gate_up.update_weight(gate_up_weights, gate_up_scale) self.down.update_weight(down_weights, down_scale) - def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_idx: torch.LongTensor): + def forward_default(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_idx: torch.LongTensor): state = { 'hidden_states': hidden_states, 'topk_idx': topk_idx, @@ -683,6 +828,13 @@ def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ out_state = self.combine(gemm_state) return out_state['hidden_states'] + def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_idx: torch.LongTensor): + + if self.tp > 1 and self.tp_mode == TPMode.DP_TP: + return self.forward_dptp.forward(hidden_states, topk_weights, topk_idx) + else: + return self.forward_default(hidden_states, topk_weights, topk_idx) + def before_dispatch(self, state: Dict): moe_type = state['moe_type'] if moe_type == MoeType.DSAsyncPrefill: From e75fef62a9af0c16357a92d1557c427faa3bd9bf Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 26 Sep 2025 19:08:14 +0800 Subject: [PATCH 13/15] fix --- lmdeploy/pytorch/nn/linear/base.py | 124 +++++++++++++++++++++++++++-- 1 file changed, 116 insertions(+), 8 deletions(-) diff --git a/lmdeploy/pytorch/nn/linear/base.py b/lmdeploy/pytorch/nn/linear/base.py index c69f8a580a..0229e4b67d 100644 --- a/lmdeploy/pytorch/nn/linear/base.py +++ b/lmdeploy/pytorch/nn/linear/base.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from typing import List, Optional +from typing import Callable, List, Optional import torch import torch.distributed as dist @@ -13,6 +13,92 @@ from .utils import update_tp_args +class LinearForwardDPTP: + + def __init__(self, gemm_func: Callable, max_tokens_per_round: int = 8192): + """Linear forward dp tp.""" + self.gemm_func = gemm_func + self.dist_ctx = get_dist_manager().current_context() + self.dist_config = self.dist_ctx.dist_config + self.tp = self.dist_config.mlp_tp + self.attn_tp = self.dist_config.attn_tp + + tp_group = self.dist_ctx.mlp_tp_group + self.rank = tp_group.rank + self.gather_rank = self.rank // self.attn_tp + self.gather_group = tp_group.gpu_gather_group + self.tp_group = tp_group.gpu_group + self.max_tokens_per_round = max_tokens_per_round * self.attn_tp // self.tp // 2 + + def all_gather(self, hidden_states: torch.Tensor, tp_sizes: List[int]): + """All gather.""" + hidden_states, handle = dist.gather_by_tp_sizes(hidden_states, tp_sizes, group=self.gather_group, async_op=True) + return hidden_states, handle + + def reduce_scatter(self, hidden_states: torch.Tensor, out_states: torch.Tensor, tp_sizes: List[int]): + """Reduce scatter.""" + hidden_states_list = list(hidden_states.split(tp_sizes, -2)) + hidden_states_list[self.gather_rank] = out_states + hidden_states_list = [item for item in hidden_states_list for _ in range(self.attn_tp)] + handle = dist.reduce_scatter(out_states, hidden_states_list, group=self.tp_group, async_op=True) + return out_states, handle + + def _gemm_and_reduce_scatter(self, hidden_states: torch.Tensor, output_states: torch.Tensor, tp_sizes: List[int], + handle: dist.Work): + """Gemm and reduce scatter.""" + handle.wait() + cur_out = self.gemm_func(hidden_states) + cur_out_states = cur_out.split(tp_sizes, dim=0)[self.gather_rank] + output_states.copy_(cur_out_states) + return self.reduce_scatter(cur_out, output_states, tp_sizes) + + def forward(self, hidden_states: torch.Tensor): + """forward.""" + + def __slice_tensor(tensor: torch.Tensor, slice_size: int): + """Slice tensor.""" + cur_tensor = tensor[:slice_size] + tensor = tensor[slice_size:] + return cur_tensor, tensor + + def __slice_and_gather(): + """Slice and gather.""" + nonlocal hidden_states, tp_sizes, output_states + cur_tp_sizes = tp_sizes.minimum(max_tokens_per_round) + tp_sizes -= cur_tp_sizes + cur_tp_sizes = cur_tp_sizes.tolist() + + slice_size = cur_tp_sizes[self.gather_rank] + cur_hidden_states, hidden_states = __slice_tensor(hidden_states, slice_size) + cur_output, output_states = __slice_tensor(output_states, slice_size) + + # all gather + cur_hidden_states, handle = self.all_gather(cur_hidden_states, cur_tp_sizes) + return dict(hidden_states=cur_hidden_states, output_states=cur_output, handle=handle, tp_sizes=cur_tp_sizes) + + step_ctx = get_step_ctx_manager().current_context() + tp_sizes = step_ctx.dp_meta.moe_tp_sizes + tp_sizes = torch.tensor(tp_sizes) + max_tokens_per_round = tp_sizes.new_tensor(self.max_tokens_per_round) + + output_states = torch.empty_like(hidden_states) + return_states = output_states + + # pre + cur_inputs = __slice_and_gather() + + # main loop + while tp_sizes.sum() > 0: + next_inputs = __slice_and_gather() + self._gemm_and_reduce_scatter(**cur_inputs) + cur_inputs = next_inputs + + # post + _, handle = self._gemm_and_reduce_scatter(**cur_inputs) + handle.wait() + return return_states + + class LinearBase(nn.Module): """Base class for linear layers.""" @@ -65,6 +151,17 @@ def init_tp_args(self, is_tp: bool, all_reduce: bool, colwise: bool, 
layer_type: self.tp_group = None self.gather_group = None + if self.tp > 1 and self.tp_mode == TPMode.DP_TP: + + def _gemm_func(self, x): + out = self._forward_default(x, False, None) + + for lora_adapter in self.lora_adapters.values(): + out = lora_adapter(x, out) + return out + + self.linear_dptp_forward = LinearForwardDPTP(_gemm_func) + self._tp_args_initialized = True def get_tp_world_rank(self): @@ -93,13 +190,14 @@ def _forward_lora(self, x, tp_sizes: List[int]): dist.all_reduce(out, group=self.tp_group) return out - def forward(self, x): - """Forward of linear layer.""" - tp_sizes = None - if self.dp_gather or (self.all_reduce and self.tp_mode == TPMode.DP_TP): - step_ctx = get_step_ctx_manager().current_context() - dp_meta = step_ctx.dp_meta - tp_sizes = dp_meta.tp_sizes + def _forward_dp_tp(self, x): + """Forward dp_tp.""" + if self.dp_gather and self.all_reduce: + return self.linear_dptp_forward.forward(x) + + step_ctx = get_step_ctx_manager().current_context() + dp_meta = step_ctx.dp_meta + tp_sizes = dp_meta.tp_sizes if self.dp_gather: x = gather_by_tp_sizes(x, tp_sizes, group=self.gather_group) @@ -108,3 +206,13 @@ def forward(self, x): return self._forward_default(x, self.all_reduce, tp_sizes) else: return self._forward_lora(x, tp_sizes) + + def forward(self, x): + """Forward of linear layer.""" + if self.tp > 1 and self.tp_mode == TPMode.DP_TP: + return self._forward_dp_tp(x) + + if len(self.lora_adapters) == 0: + return self._forward_default(x, self.all_reduce, None) + else: + return self._forward_lora(x) From cc96976f0e46c7bf01334b3f0027e3f40e6c97ee Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 26 Sep 2025 19:37:54 +0800 Subject: [PATCH 14/15] fix --- lmdeploy/pytorch/nn/moe.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/lmdeploy/pytorch/nn/moe.py b/lmdeploy/pytorch/nn/moe.py index 70451e5ee0..5012969996 100644 --- a/lmdeploy/pytorch/nn/moe.py +++ b/lmdeploy/pytorch/nn/moe.py @@ -603,13 +603,14 @@ def __init__(self, device=device) weight_scale_inv = torch.nn.Parameter(weight_scale_inv, requires_grad=False) self.register_parameter('weight_scale_inv', weight_scale_inv) - self.weight._base_weight_loader = self.weight_loader_tp - self.weight.weight_loader = self.weight_loader_with_quant if self.ep: + self.weight._base_weight_loader = self.weight.weight_loader self.weight_scale_inv.weight_loader = self.weight_loader_scale_ep else: + self.weight._base_weight_loader = self.weight_loader_tp_blocked_fp8 self.weight_scale_inv.weight_loader = self.weight_loader_scale_tp + self.weight.weight_loader = self.weight_loader_with_quant def update_weight(self, weight: torch.Tensor, weight_scale_inv: torch.Tensor): """Update weight.""" @@ -633,7 +634,8 @@ def _chunk_weight_tp(self, weight: torch.Tensor, dim: int, world_size: int, rank split_size = _split_size(weight.size(dim), world_size, align) return weight.split(split_size, dim=dim)[rank] - def weight_loader_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, shard_id: str): + def weight_loader_tp_blocked_fp8(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, + shard_id: str): """Weight loader.""" world_size, rank = get_tp_world_rank('moe') if shard_id == 'gate': From f45441468e77f5578c8a52cdea21ec11a59c25cc Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Wed, 15 Oct 2025 19:09:45 +0800 Subject: [PATCH 15/15] patch deep-gemm --- lmdeploy/pytorch/backends/cuda/moe.py | 227 +++++------------- .../pytorch/third_party/deep_gemm/__init__.py | 40 +++ 2 files 
changed, 99 insertions(+), 168 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py index 59831d8146..d0e9e663f2 100644 --- a/lmdeploy/pytorch/backends/cuda/moe.py +++ b/lmdeploy/pytorch/backends/cuda/moe.py @@ -1,11 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Callable, List, Optional +import contextlib +from typing import Callable, List import torch import torch.distributed as dist -from lmdeploy.pytorch.backends.cuda.token_dispatcher import DeepEPTokenDispatcherLowLatency, TokenDispatcherBuilder from lmdeploy.pytorch.distributed import get_dist_manager from lmdeploy.pytorch.kernels.cuda import fused_moe, fused_moe_w8a8 from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import fused_moe_blocked_fp8 @@ -412,167 +412,39 @@ def forward( return down_output -class FusedMoENormal: +@contextlib.contextmanager +def monk_deep_gemm(): + from dlblas.kernels.fused_moe_v3 import use_deep_gemm + if use_deep_gemm: + yield + return - def __init__(self, - ep_size: int, - ep_group: dist.ProcessGroup, - num_experts: int, - hidden_dim: int, - block_size: int = 128, - out_dtype: torch.dtype = torch.bfloat16): - self.experts = DeepEPExpertsGroupedGEMM(num_experts, ep_size, [block_size, block_size]) - self.token_dispatcher = TokenDispatcherBuilder.build( - group=ep_group, - num_experts=num_experts, - num_local_experts=num_experts // ep_size, - hidden_size=hidden_dim, - params_dtype=out_dtype, - ) - - def forward(self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.LongTensor, - up_weights: torch.Tensor, - up_scale: torch.Tensor, - down_weights: torch.Tensor, - down_scale: torch.Tensor, - expert_list: List[int] = None): - """forward.""" - recv_hidden_states, recv_topk_ids, recv_topk_weights, tokens_per_expert = self.token_dispatcher.dispatch( - hidden_states, - topk_ids, - topk_weights, - expert_list, - ) - out_states = self.experts.forward(recv_hidden_states, tokens_per_expert, up_weights, up_scale, down_weights, - down_scale) - out_states = self.token_dispatcher.combine(out_states) - return out_states - - def capture(self): - return self.token_dispatcher.buffer_normal.capture() - - def wait(self, event): - self.token_dispatcher.release() - event.current_stream_wait() - - def dispatch_async(self, - x: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, - num_experts: Optional[int] = None, - previous_event=None, - async_finish=True): - return self.token_dispatcher.dispatch_normal_async(x, topk_idx, topk_weights, num_experts, previous_event, - async_finish) - - def combine_async(self, x: torch.Tensor, handle: tuple, previous_event=None, async_finish=True): - return self.token_dispatcher.combine_normal_async(x, handle, previous_event, async_finish) - - def release(self): - return self.token_dispatcher.release() - - def fusedmoe_forward(self, state, up_weight, up_scale, down_weight, down_scale): - ( - hidden_states, - recv_hidden_states_shape, - dispatched_routing_map, - topk_weights, - reversed_mapping_for_combine, - ) = self.token_dispatcher.get_permuted_hidden_states_by_experts(state['recv_hidden_states'], - state['recv_topk_idx'], - state['recv_topk_weights'], - state['num_experts']) - tokens_per_expert = torch.tensor( - state['recv_tokens_per_expert'], - device=hidden_states.device, - dtype=torch.int64, - ) - hidden_states = self.experts.forward(hidden_states, tokens_per_expert, up_weight, up_scale, down_weight, - down_scale) - hidden_states = 
self.token_dispatcher.get_restored_hidden_states_by_experts(hidden_states, - reversed_mapping_for_combine, - recv_hidden_states_shape, - dispatched_routing_map, - topk_weights) - return hidden_states + # patch deep_gemm + import deep_gemm + import dlblas + from lmdeploy.pytorch.third_party import deep_gemm as patched_deep_gemm + func0_ = getattr(deep_gemm, 'get_col_major_tma_aligned_tensor', None) + func1_ = getattr(deep_gemm, 'm_grouped_gemm_fp8_fp8_bf16_nt_masked', None) + deep_gemm.get_col_major_tma_aligned_tensor = patched_deep_gemm.get_mn_major_tma_aligned_tensor + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked = patched_deep_gemm.m_grouped_fp8_gemm_nt_masked -class FusedMoELowLatency: + # patch dlblas + dlblas.kernels.fused_moe_v3.use_deep_gemm = True + dlblas.kernels.fused_moe_v3.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous = \ + patched_deep_gemm.m_grouped_fp8_gemm_nt_contiguous + yield - def __init__(self, - ep_size: int, - ep_group: dist.ProcessGroup, - num_experts: int, - hidden_dim: int, - block_size: int = 128, - out_dtype: torch.dtype = torch.bfloat16): - self.num_experts = num_experts - self.experts = DeepEPExpertsDeepGEMM(num_experts, ep_size, block_size, out_dtype) - self.token_dispatcher = DeepEPTokenDispatcherLowLatency( - group=ep_group, - num_experts=num_experts, - num_local_experts=num_experts // ep_size, - hidden_size=hidden_dim, - params_dtype=out_dtype, - ) - - def forward(self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.LongTensor, - up_weights: torch.Tensor, - up_scale: torch.Tensor, - down_weights: torch.Tensor, - down_scale: torch.Tensor, - expert_list: List[int] = None): - """forward.""" - recv_hidden_states, topk_idx, topk_weights, masked_m, expected_m = self.token_dispatcher.dispatch( - hidden_states, - topk_ids, - topk_weights, - self.num_experts, - ) - out_states = self.experts.forward(recv_hidden_states, up_weights, up_scale, down_weights, down_scale, masked_m, - expected_m) - out_states = self.token_dispatcher.combine(out_states, topk_idx, topk_weights) - return out_states - - def wait(self, event): - event.current_stream_wait() - - def dispatch_async( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - num_experts: Optional[int] = None, - use_fp8: bool = True, - async_finish: bool = True, - ): - return self.token_dispatcher.dispatch_async(hidden_states, topk_idx, num_experts, use_fp8, async_finish) - - def combine_async( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - topk_weights: torch.Tensor, - handle: tuple, - async_finish: bool, - ): - return self.token_dispatcher.combine_async(hidden_states, topk_idx, topk_weights, handle, async_finish) + # unpatch dlblas + dlblas.kernels.fused_moe_v3.use_deep_gemm = False - def fusedmoe_forward(self, state, up_weight, up_scale, down_weight, down_scale): - recv_hidden_states = state['recv_hidden_states'] - recv_expert_count = state['recv_expert_count'] - hidden_shape = state['raw_hidden_shape'] - topk_idx = state['topk_idx'] - expected_m = (hidden_shape[0] * self.token_dispatcher.buffer_low_latency.group_size * topk_idx.shape[1] + - self.token_dispatcher.num_experts) // self.token_dispatcher.num_experts - hidden_states = self.experts.forward(recv_hidden_states, up_weight, up_scale, down_weight, down_scale, - recv_expert_count, expected_m) - return hidden_states + # unpatch deep_gemm + if func0_ is not None: + deep_gemm.get_col_major_tma_aligned_tensor = func0_ + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked = func1_ + else: + del 
deep_gemm.get_col_major_tma_aligned_tensor + del deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked class FusedDeepEpMoEBlockedF8Impl(TritonFusedMoEBlockedF8Impl): @@ -690,16 +562,35 @@ def do_renormalize(self, topk_weights): def fusedmoe_build(self, low_latency_mode: bool = False): from dlblas.layers.moe.ep_moe import build_deepep_moe - return build_deepep_moe(low_latency_mode, - self.ep_size, - self.ep_group, - self.num_experts, - self.hidden_dim, - self.block_size, - self.top_k, - self.out_dtype, - layer_idx=self.layer_idx, - chunk_size=16 * 1024) + deepep_moe = build_deepep_moe(low_latency_mode, + self.ep_size, + self.ep_group, + self.num_experts, + self.hidden_dim, + self.block_size, + self.top_k, + self.out_dtype, + layer_idx=self.layer_idx, + chunk_size=16 * 1024) + + # patch forward + _origin_forward = deepep_moe.forward + _origin_fusedmoe_forward = deepep_moe.fusedmoe_forward + + def _patched_forward(*args, **kwargs): + with monk_deep_gemm(): + out = _origin_forward(*args, **kwargs) + return out + + def _patched_fusedmoe_forward(*args, **kwargs): + with monk_deep_gemm(): + out = _origin_fusedmoe_forward(*args, **kwargs) + return out + + deepep_moe.forward = _patched_forward + deepep_moe.fusedmoe_forward = _patched_fusedmoe_forward + + return deepep_moe class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder): diff --git a/lmdeploy/pytorch/third_party/deep_gemm/__init__.py b/lmdeploy/pytorch/third_party/deep_gemm/__init__.py index 1e734c4073..369862e60e 100644 --- a/lmdeploy/pytorch/third_party/deep_gemm/__init__.py +++ b/lmdeploy/pytorch/third_party/deep_gemm/__init__.py @@ -42,3 +42,43 @@ def fp8_gemm_nt(a, b, d, c, recipe=None, compiled_dim='nk', disable_ue8m0_cast=F N, _ = b[0].shape with _log_jit_build(M, N, K): gemm_fp8_fp8_bf16_nt(a, b, d) + + +try: + from deep_gemm import m_grouped_fp8_gemm_nt_contiguous +except Exception: + from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_contiguous + + def m_grouped_fp8_gemm_nt_contiguous(a, b, d, m_indices, recipe=None, compiled_dims='nk', disable_ue8m0_cast=False): + assert recipe is None + assert compiled_dims == 'nk' + assert disable_ue8m0_cast is False + return m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(a, b, d, m_indices) + + +try: + from deep_gemm import m_grouped_fp8_gemm_nt_masked +except Exception: + from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked + + def m_grouped_fp8_gemm_nt_masked(a, + b, + d, + masked_m, + expected_m, + recipe=None, + compiled_dims='nk', + disable_ue8m0_cast=False): + assert recipe is None + assert compiled_dims == 'nk' + assert disable_ue8m0_cast is False + return m_grouped_gemm_fp8_fp8_bf16_nt_masked(a, b, d, masked_m, expected_m) + + +try: + from deep_gemm import get_mn_major_tma_aligned_tensor +except Exception: + from deep_gemm import get_col_major_tma_aligned_tensor + + def get_mn_major_tma_aligned_tensor(x): + return get_col_major_tma_aligned_tensor(x)
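
Note on the "fix not aligned weight" change earlier in this series: the standalone sketch below copies the `_split_size` helper from lmdeploy/pytorch/nn/moe.py and runs it on two illustrative shapes (the dimensions are examples for this note, not taken from any real model config). It shows why the weight loaders can now accept an ffn dim that is not divisible by tp * block_size: every shard stays a multiple of the fp8 block size, and ranks simply receive unequal shares when the dimension does not divide evenly.

# Standalone sketch; `_split_size` is copied from the "fix not aligned weight"
# change to lmdeploy/pytorch/nn/moe.py, with explanatory comments added.
# The demo shapes below are illustrative only.
def _split_size(size: int, world_size: int, align: int):
    # Work in units of `align` so every shard stays block-aligned.
    size = size // align
    assert size >= world_size
    base = size // world_size
    remain = size % world_size
    # The first `remain` ranks take one extra block each.
    split_size = [base + 1] * remain + [base] * (world_size - remain)
    split_size = [s * align for s in split_size]
    return split_size


if __name__ == '__main__':
    # 1536 = 12 fp8 blocks of 128 split over 4 moe-tp ranks -> equal shards.
    print(_split_size(1536, 4, 128))  # [384, 384, 384, 384]
    # 1792 = 14 blocks over 4 ranks -> unequal but still block-aligned shards.
    print(_split_size(1792, 4, 128))  # [512, 512, 384, 384]

The same helper backs both `weight_loader_tp_blocked_fp8` (via `_chunk_weight_tp`) and the `_update_args` override in that patch, so each rank allocates and loads exactly its block-aligned slice instead of assuming an evenly divisible ffn dim.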