diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 73c02e2914..4441e98c99 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -289,6 +289,9 @@ class PytorchEngineConfig: session_len (int): Max session length. Default None. max_batch_size (int): Max batch size. If it is not specified, the engine will automatically set it according to the device + attn_tp_size (int): tp size for attention, only works for dp>1 + mlp_tp_size (int): tp size for mlp, only works for dp>1 + moe_tp_size (int): tp size for moe, only works for dp>1 cache_max_entry_count (float): the percentage of gpu memory occupied by the k/v cache. For lmdeploy versions greater than `v0.2.1`, it defaults to 0.8, signifying the percentage of FREE GPU memory @@ -350,6 +353,9 @@ class PytorchEngineConfig: ep: int = 1 session_len: int = None max_batch_size: int = None + attn_tp_size: int = None + mlp_tp_size: int = None + moe_tp_size: int = None cache_max_entry_count: float = 0.8 prefill_interval: int = 16 block_size: int = 64 diff --git a/lmdeploy/pytorch/backends/awq_modules.py b/lmdeploy/pytorch/backends/awq_modules.py index bfd5372ba3..1a9815c423 100644 --- a/lmdeploy/pytorch/backends/awq_modules.py +++ b/lmdeploy/pytorch/backends/awq_modules.py @@ -17,7 +17,12 @@ def update_weights(self, return qweight, scales, qzeros, bias @abstractmethod - def forward(self, x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False): + def forward(self, + x, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + all_reduce: bool = False, + group: Optional[torch.distributed.ProcessGroup] = None): """forward.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/blockedf8_modules.py b/lmdeploy/pytorch/backends/blockedf8_modules.py index 6750654a79..b30bb56fae 100644 --- a/lmdeploy/pytorch/backends/blockedf8_modules.py +++ b/lmdeploy/pytorch/backends/blockedf8_modules.py @@ -3,6 +3,7 @@ from typing import List, Optional import torch +import torch.distributed as dist class LinearBlockedF8Impl(ABC): @@ -19,6 +20,7 @@ def forward(self, scale: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False, + group: Optional[dist.ProcessGroup] = None, rank: int = 0, scatter_size: List[int] = None): """forward.""" diff --git a/lmdeploy/pytorch/backends/cuda/attention.py b/lmdeploy/pytorch/backends/cuda/attention.py index b241c384b2..b896b0579e 100644 --- a/lmdeploy/pytorch/backends/cuda/attention.py +++ b/lmdeploy/pytorch/backends/cuda/attention.py @@ -90,7 +90,7 @@ def __init__( self.flash_attention_fwd = flash_attention_fwd # for alibi attention - world_size, rank = get_tp_world_rank() + world_size, rank = get_tp_world_rank('attn') self.alibi_head_offset = self.num_heads * rank self.alibi_num_heads = self.num_heads * world_size self.block_sparse_size = block_sparse_size diff --git a/lmdeploy/pytorch/backends/cuda/awq_modules.py b/lmdeploy/pytorch/backends/cuda/awq_modules.py index be23e06e79..01516a8aca 100644 --- a/lmdeploy/pytorch/backends/cuda/awq_modules.py +++ b/lmdeploy/pytorch/backends/cuda/awq_modules.py @@ -55,12 +55,13 @@ def forward(self, scales: torch.Tensor, qzeros: torch.Tensor, bias: Optional[torch.Tensor] = None, - all_reduce: bool = False): + all_reduce: bool = False, + group: Optional[torch.distributed.ProcessGroup] = None): """forward.""" out_features = scales.size(1) out = wq_gemm_forward(x, qweight, qzeros, scales, self.w_bit, self.group_size, bias, out_features) if all_reduce: - dist.all_reduce(out) + dist.all_reduce(out, group=group) return 
out diff --git a/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py b/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py index 25d69806b3..e25e9c52e7 100644 --- a/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py +++ b/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py @@ -13,15 +13,6 @@ logger = get_logger('lmdeploy') -def _reduce_scatter_input(out: torch.Tensor, rank: int, tp_sizes: List[int]): - """Reduce scatter.""" - outs = out.split(tp_sizes, -2) - out = outs[rank] - outs = list(outs) - dist.reduce_scatter(out, outs) - return out - - class TritonLinearBlockedF8Impl(LinearBlockedF8Impl): """Triton linear blocked f8 implementation.""" @@ -37,6 +28,7 @@ def forward(self, scale: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False, + group: Optional[dist.ProcessGroup] = None, rank: int = 0, scatter_size: List[int] = None): """forward.""" @@ -52,7 +44,7 @@ def forward(self, if all_reduce: if scatter_size is not None: - out = _reduce_scatter_input(out, rank, scatter_size) + out = dist.reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group) else: dist.all_reduce(out) return out @@ -117,6 +109,7 @@ def forward(self, scale: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False, + group: Optional[dist.ProcessGroup] = None, rank: int = 0, scatter_size: List[int] = None): """forward.""" @@ -128,12 +121,11 @@ def forward(self, out = out[:x.size(0)] if bias is not None: out += bias + out = out.unflatten(0, x_shape[:-1]) if all_reduce: if scatter_size is not None: - out = _reduce_scatter_input(out, rank, scatter_size) + out = dist.reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group) else: - dist.all_reduce(out) - - out = out.unflatten(0, x_shape[:-1]) + dist.all_reduce(out, group=group) return out diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index deb6c66bfd..536ab65e00 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -258,7 +258,7 @@ def update_inputs(self, inputs): meta = self.get_meta() padding_batch_size = meta.padding_batch_size tp_size = self._get_capture_tokens(padding_batch_size) - dp_meta.tp_sizes = [tp_size] * len(dp_meta.tp_sizes) + dp_meta.sync_tp_size(tp_size) return inputs def get_capture_batch_sizes(self) -> List[int]: diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py index a1394f28bd..59831d8146 100644 --- a/lmdeploy/pytorch/backends/cuda/moe.py +++ b/lmdeploy/pytorch/backends/cuda/moe.py @@ -619,6 +619,45 @@ def ep_expert_list(self, world_size: int, rank: int): else: return super().ep_expert_list(world_size=world_size, rank=rank) + def _split_inputs_by_attn_tp( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.LongTensor, + ): + """Split input by attn tp.""" + dist_ctx = get_dist_manager().current_context() + attn_tp = dist_ctx.dist_config.attn_tp + attn_rank = dist_ctx.attn_tp_group.rank + num_states = hidden_states.size(0) + + if attn_tp == 1 or attn_tp > num_states: + return hidden_states, topk_weights, topk_ids, None + + # split size + base = num_states // attn_tp + remain = num_states % attn_tp + split_size = [base + 1] * remain + [base] * (attn_tp - remain) + + # split inputs + hidden_states = torch.split(hidden_states, split_size, dim=0)[attn_rank] + topk_weights = torch.split(topk_weights, split_size, dim=0)[attn_rank] + topk_ids = torch.split(topk_ids, split_size, dim=0)[attn_rank] + + return 
hidden_states, topk_weights, topk_ids, split_size + + def _gather_outputs_by_attn_tp(self, out_states: torch.Tensor, split_size: List[int]): + """Gather output by attn tp.""" + if split_size is None: + return out_states + + dist_ctx = get_dist_manager().current_context() + gpu_group = dist_ctx.attn_tp_group.gpu_group + new_out_states = out_states.new_empty((sum(split_size), out_states.shape[1])) + new_out_states_list = list(new_out_states.split(split_size, dim=0)) + dist.all_gather(new_out_states_list, out_states, group=gpu_group) + return new_out_states + def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, @@ -633,12 +672,17 @@ def forward(self, act_func: Callable = None, **kwargs): """forward.""" + hidden_states, topk_weights, topk_ids, split_size = self._split_inputs_by_attn_tp( + hidden_states, topk_weights, topk_ids) + topk_weights = self.do_renormalize(topk_weights) step_ctx = get_step_ctx_manager().current_context() low_latency_mode = step_ctx.is_decoding and self.use_deep_gemm moe = self.fusedmoe_build(low_latency_mode) out_states = moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale, down_weights, down_scale, expert_list) + + out_states = self._gather_outputs_by_attn_tp(out_states, split_size) return out_states def do_renormalize(self, topk_weights): diff --git a/lmdeploy/pytorch/backends/cuda/qmodules.py b/lmdeploy/pytorch/backends/cuda/qmodules.py index 9f2576c3f8..dc61787731 100644 --- a/lmdeploy/pytorch/backends/cuda/qmodules.py +++ b/lmdeploy/pytorch/backends/cuda/qmodules.py @@ -63,7 +63,8 @@ def forward(self, weight: torch.Tensor, scale: torch.Tensor, bias: Optional[torch.Tensor] = None, - all_reduce: bool = False): + all_reduce: bool = False, + group: Optional[torch.distributed.ProcessGroup] = None): """forward.""" if isinstance(x, torch.Tensor): input_quant, input_scale = per_token_quant_int8(x, 1e-7, quant_dtype=self.quant_dtype) @@ -79,7 +80,7 @@ def forward(self, bias=bias) if all_reduce: - dist.all_reduce(out) + dist.all_reduce(out, group=group) return out diff --git a/lmdeploy/pytorch/backends/default/awq_modules.py b/lmdeploy/pytorch/backends/default/awq_modules.py index 5c80e0e327..d2253920fa 100644 --- a/lmdeploy/pytorch/backends/default/awq_modules.py +++ b/lmdeploy/pytorch/backends/default/awq_modules.py @@ -62,7 +62,8 @@ def forward(self, scales: torch.Tensor, qzeros: torch.Tensor, bias: Optional[torch.Tensor] = None, - all_reduce: bool = False): + all_reduce: bool = False, + group: Optional[torch.distributed.ProcessGroup] = None): """forward.""" out_shape = x.shape[:-1] + (self.out_features, ) input_dtype = x.dtype @@ -77,7 +78,7 @@ def forward(self, if input_dtype != torch.float16: out = out.to(dtype=input_dtype) if all_reduce: - dist.all_reduce(out) + dist.all_reduce(out, group=group) return out diff --git a/lmdeploy/pytorch/backends/default/linear.py b/lmdeploy/pytorch/backends/default/linear.py index b223b498ae..f766123fff 100644 --- a/lmdeploy/pytorch/backends/default/linear.py +++ b/lmdeploy/pytorch/backends/default/linear.py @@ -2,26 +2,12 @@ from typing import List, Optional import torch +import torch.distributed as dist import torch.nn.functional as F -import lmdeploy.pytorch.distributed as dist - from ..linear import LinearBuilder, LinearImpl -def _reduce_scatter_input(out: torch.Tensor, rank: int, tp_sizes: List[int]): - """Reduce scatter.""" - out = out.transpose(0, -2) - if not out.is_contiguous(): - out = out.contiguous() - outs = out.split(tp_sizes, 0) - out = outs[rank] - outs = list(outs) - 
dist.reduce_scatter(out, outs) - out = out.transpose(0, -2) - return out - - class DefaultLinearImpl(LinearImpl): """Linear implementation api.""" @@ -30,15 +16,17 @@ def forward(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False, + group: dist.ProcessGroup = None, rank: int = 0, scatter_size: List[int] = None): """forward.""" out = F.linear(x, weight, bias) if all_reduce: if scatter_size is not None: - out = _reduce_scatter_input(out, rank, scatter_size) + from lmdeploy.pytorch.distributed import reduce_scatter_by_tp_sizes + out = reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group) else: - dist.all_reduce(out) + dist.all_reduce(out, group=group) return out diff --git a/lmdeploy/pytorch/backends/dlinfer/awq_modules.py b/lmdeploy/pytorch/backends/dlinfer/awq_modules.py index 8dcf750478..1ec8bf0072 100644 --- a/lmdeploy/pytorch/backends/dlinfer/awq_modules.py +++ b/lmdeploy/pytorch/backends/dlinfer/awq_modules.py @@ -23,7 +23,8 @@ def forward(self, scales: torch.Tensor, qzeros: torch.Tensor, bias: Optional[torch.Tensor] = None, - all_reduce: bool = False): + all_reduce: bool = False, + group: Optional[torch.distributed.ProcessGroup] = None): """forward.""" out = awq_linear(x, qweight, scales, qzeros, bias, all_reduce, self.group_size) return out diff --git a/lmdeploy/pytorch/backends/dlinfer/linear.py b/lmdeploy/pytorch/backends/dlinfer/linear.py index 327eeae56b..ec682bba8b 100644 --- a/lmdeploy/pytorch/backends/dlinfer/linear.py +++ b/lmdeploy/pytorch/backends/dlinfer/linear.py @@ -3,8 +3,8 @@ from typing import List, Optional import torch +import torch.distributed as dist -import lmdeploy.pytorch.distributed as dist from lmdeploy.pytorch.kernels.dlinfer import linear from ..linear import LinearBuilder, LinearImpl @@ -32,12 +32,13 @@ def forward(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False, + group: dist.ProcessGroup = None, rank: int = 0, scatter_size: List[int] = None): """forward.""" out = linear(x, weight, bias, False) if all_reduce: - dist.all_reduce(out) + dist.all_reduce(out, group=group) return out diff --git a/lmdeploy/pytorch/backends/dlinfer/qmodules.py b/lmdeploy/pytorch/backends/dlinfer/qmodules.py index bdd5259fd0..fe52dd5f35 100644 --- a/lmdeploy/pytorch/backends/dlinfer/qmodules.py +++ b/lmdeploy/pytorch/backends/dlinfer/qmodules.py @@ -36,7 +36,8 @@ def forward(self, weight: torch.Tensor, scale: torch.Tensor, bias: Optional[torch.Tensor] = None, - all_reduce: bool = False): + all_reduce: bool = False, + group: Optional[torch.distributed.ProcessGroup] = None): """forward.""" if isinstance(x, torch.Tensor): input_quant, input_scale = dynamic_quant(x, self.quant_dtype) @@ -46,7 +47,7 @@ def forward(self, out = linear_w8a8(input_quant, weight, input_scale, scale, self.out_dtype, self.quant_dtype, bias) if all_reduce: - dist.all_reduce(out) + dist.all_reduce(out, group=group) return out diff --git a/lmdeploy/pytorch/backends/linear.py b/lmdeploy/pytorch/backends/linear.py index d577dcba54..740b4b7ecc 100644 --- a/lmdeploy/pytorch/backends/linear.py +++ b/lmdeploy/pytorch/backends/linear.py @@ -3,6 +3,7 @@ from typing import List, Optional import torch +import torch.distributed as dist class LinearImpl(ABC): @@ -18,6 +19,7 @@ def forward(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False, + group: dist.ProcessGroup = None, rank: int = 0, scatter_size: List[int] = None): """forward.""" diff --git a/lmdeploy/pytorch/backends/qmodules.py 
b/lmdeploy/pytorch/backends/qmodules.py index 7cb485888b..7173fb5f34 100644 --- a/lmdeploy/pytorch/backends/qmodules.py +++ b/lmdeploy/pytorch/backends/qmodules.py @@ -47,7 +47,8 @@ def forward(self, weight: torch.Tensor, scale: torch.Tensor, bias: Optional[torch.Tensor] = None, - all_reduce: bool = False): + all_reduce: bool = False, + group: Optional[torch.distributed.ProcessGroup] = None): """forward.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/check_env/dist.py b/lmdeploy/pytorch/check_env/dist.py index e31e6c9161..eb6e48e882 100644 --- a/lmdeploy/pytorch/check_env/dist.py +++ b/lmdeploy/pytorch/check_env/dist.py @@ -44,9 +44,9 @@ def check(self): if self.device_type == 'cuda' and not is_dlblas_installed(): self.log_and_exit(mod_name='Dist', message='ep>1 requires install dlblas(https://github.com/DeepLink-org/dlBLAS).') - if self.dp % self.ep != 0: + if self.ep % self.dp != 0: self.log_and_exit(mod_name='Dist', - message=f'ep>1 requires dp % ep == 0. Get dp={self.dp} and ep={self.ep}.') + message=f'ep>1 requires ep % dp == 0. Get dp={self.dp} and ep={self.ep}.') elif self.dist_config.enable_eplb: self.log_and_exit(mod_name='Dist', message=f'Enable eplb requires ep > 1. Get ep={self.ep}.') diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index ac3459e045..5fded92303 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -100,35 +100,105 @@ def __post_init__(self): self.enable_prefix_caching = False +class TPMode(enum.Enum): + """TP Mode.""" + DEFAULT = enum.auto() + DP_TP = enum.auto() + + @dataclass class DistConfig: dp: int = 1 - tp: int = 1 ep: int = 1 dp_rank: int = 0 enable_microbatch: bool = False enable_eplb: bool = False - world_size: int = None - attn_config: 'DistConfig' = None + world_size: int = 1 + + # tp + tp: int = 1 # default tp, equal to attn_tp + attn_tp: int = None # tp for attention + mlp_tp: int = None # tp for mlp + moe_tp: int = None # tp for moe + + # tp mode + mlp_tp_mode: TPMode = TPMode.DEFAULT + moe_tp_mode: TPMode = TPMode.DEFAULT def __post_init__(self): """Post init.""" assert self.dp_rank < self.dp assert self.dp >= 1 - if self.dp == 1: - world_size = max(self.tp, self.ep) - attn_config = self - else: - world_size = self.dp - attn_config = DistConfig(dp=1, tp=1, ep=1, dp_rank=0) + + dp = self.dp + tp = self.tp + ep = self.ep + + # ignore layer to for dp==1 + if dp == 1: + self.mlp_tp = None + self.attn_tp = None + self.moe_tp = None + + # mlp and moe tp + self.mlp_tp = self.mlp_tp or tp + self.moe_tp = self.moe_tp or (1 if ep > 1 else self.mlp_tp) + + # world_size + world_size = ep if ep > 1 else max(self.mlp_tp, self.moe_tp) self.world_size = world_size - self.attn_config = attn_config + assert (world_size >= dp and world_size % dp == 0), (f'world_size {world_size}, dp {dp}') + assert (world_size >= ep and world_size % ep == 0), (f'world_size {world_size}, ep {ep}') + assert (world_size >= self.mlp_tp + and world_size % self.mlp_tp == 0), (f'world_size {world_size}, mlp_tp {self.mlp_tp}') + assert (world_size >= self.moe_tp + and world_size % self.moe_tp == 0), (f'world_size {world_size}, moe_tp {self.moe_tp}') + + # attn tp + self.attn_tp = self.attn_tp or self.world_size // dp + self.tp = self.attn_tp + if self.mlp_tp > 1: + assert (self.mlp_tp >= self.attn_tp + and self.mlp_tp % self.attn_tp == 0), (f'mlp_tp {self.mlp_tp}, attn_tp {self.attn_tp}') + if self.moe_tp > 1: + assert (self.moe_tp >= self.attn_tp + and self.moe_tp % self.attn_tp == 0), (f'moe_tp {self.moe_tp}, attn_tp {self.attn_tp}') 
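        # Worked examples of how the sizes resolve (illustrative values, not taken
        # from the patch; the tp modes are assigned just below):
        #   dp=2, tp=2, ep=1, mlp_tp=4 -> moe_tp=4, world_size=4,
        #       attn_tp=world_size//dp=2, mlp_tp_mode=moe_tp_mode=TPMode.DP_TP
        #   dp=2, tp=2, ep=4           -> mlp_tp=2, moe_tp=1, world_size=ep=4,
        #       attn_tp=2, both tp modes stay TPMode.DEFAULT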
+ assert (world_size >= self.attn_tp + and world_size % self.attn_tp == 0), (f'world_size {world_size}, attn_tp {self.attn_tp}') + + # tp mode + self.mlp_tp_mode = TPMode.DEFAULT if (self.mlp_tp in [1, self.attn_tp]) else TPMode.DP_TP + self.moe_tp_mode = TPMode.DEFAULT if (self.moe_tp in [1, self.attn_tp]) else TPMode.DP_TP + + def get_tp_by_layer(self, layer_type: str): + """Get tp by layer type.""" + if layer_type == 'attn': + return self.attn_tp, TPMode.DEFAULT + elif layer_type == 'mlp': + return self.mlp_tp, self.mlp_tp_mode + elif layer_type == 'moe': + return self.moe_tp, self.moe_tp_mode + elif layer_type is None: + # for some layer that we don't need tp + return 1, TPMode.DEFAULT + else: + raise ValueError(f'Unknown layer type: {layer_type}') - def need_dummy_batch(self): - """Need dummy batch.""" - if self.dp == 1: - return False - return self.tp > 1 or self.ep > 1 + @classmethod + def from_engine_config(cls, engine_config: PytorchEngineConfig): + """From engine config.""" + dist_config = cls( + dp=engine_config.dp, + ep=engine_config.ep, + dp_rank=engine_config.dp_rank, + enable_microbatch=engine_config.enable_microbatch, + enable_eplb=engine_config.enable_eplb, + tp=engine_config.tp, + attn_tp=engine_config.attn_tp_size, + mlp_tp=engine_config.mlp_tp_size, + moe_tp=engine_config.moe_tp_size, + ) + return dist_config def _override_hf_config_dict(hf_config: dict, key: str, hf_overrides): @@ -261,10 +331,7 @@ def from_hf_config(cls, from lmdeploy.pytorch.configurations import AutoModelConfigBuilder if dist_config is None: dist_config = DistConfig() - if dist_config.dp == 1: - tp = dist_config.tp - else: - tp = 1 + tp = dist_config.attn_tp model_config = AutoModelConfigBuilder.build(hf_config, model_path, tp=tp) diff --git a/lmdeploy/pytorch/distributed.py b/lmdeploy/pytorch/distributed.py index 46b89a1b04..8cf6a15a28 100644 --- a/lmdeploy/pytorch/distributed.py +++ b/lmdeploy/pytorch/distributed.py @@ -2,125 +2,272 @@ import threading from contextlib import contextmanager from dataclasses import dataclass -from typing import List +from datetime import timedelta +from typing import List, Optional +import torch from torch import distributed as dist -from torch.distributed import ReduceOp +from torch.distributed import ProcessGroup, ReduceOp, Work # noqa: F401 -from .config import DistConfig +from .config import DistConfig, TPMode + + +@dataclass +class DistGroup: + """Distributed group.""" + rank: int = 0 + cpu_group: dist.ProcessGroup = None + gpu_group: dist.ProcessGroup = None + cpu_groups: List[dist.ProcessGroup] = None + gpu_groups: List[dist.ProcessGroup] = None + gpu_gather_group: dist.ProcessGroup = None + + def close(self): + """Close groups.""" + if not dist.is_initialized(): + return + if self.cpu_groups is not None: + for group in self.cpu_groups: + dist.destroy_process_group(group) + self.cpu_groups = None + if self.gpu_groups is not None: + for group in self.gpu_groups: + dist.destroy_process_group(group) + self.gpu_groups = None + + +def _build_tp_group_impl(tp: int, + rank: int, + world_size: int, + timeout: timedelta, + cpu_backend: str = 'gloo', + ccl_backend: str = 'nccl', + attn_tp: int = 1, + tp_mode: TPMode = TPMode.DEFAULT): + """Build tp group.""" + assert tp > 1 + tp_rank = rank % tp + tp_group_id = rank // tp + gather_group_id = (rank - tp_group_id * tp) % attn_tp + ranks = range(world_size) + tp_gpu_groups = [] + tp_cpu_groups = [] + gather_groups = [] + for start in range(0, world_size, tp): + tp_ranks = ranks[start:start + tp] + group = 
dist.new_group(ranks=tp_ranks, timeout=timeout, backend=ccl_backend) + tp_gpu_groups.append(group) + cpu_group = dist.new_group(ranks=tp_ranks, timeout=timeout, backend=cpu_backend) + tp_cpu_groups.append(cpu_group) + + # create gather group + if tp_mode == TPMode.DP_TP and attn_tp != tp: + for g_start in range(start, start + attn_tp): + g_ranks = ranks[g_start:(g_start + tp):attn_tp] + gather_group = dist.new_group(ranks=g_ranks, timeout=timeout, backend=ccl_backend) + gather_groups.append(gather_group) + tp_gpu_group = tp_gpu_groups[tp_group_id] + tp_cpu_group = tp_cpu_groups[tp_group_id] + + if tp_mode == TPMode.DP_TP: + if attn_tp == tp: + gather_group = tp_gpu_group + else: + gather_group = gather_groups[gather_group_id] + else: + gather_group = None + return DistGroup( + rank=tp_rank, + cpu_group=tp_cpu_group, + gpu_group=tp_gpu_group, + cpu_groups=tp_cpu_groups, + gpu_groups=tp_gpu_groups, + gpu_gather_group=gather_group, + ) + + +def _build_attn_tp_group(context: 'DistContext', + timeout: timedelta, + cpu_backend: str = 'gloo', + ccl_backend: str = 'nccl'): + """Build attention tp group.""" + dist_config = context.dist_config + tp = dist_config.attn_tp + # skip if tp == 1 + if tp == 1: + context.attn_tp_group = DistGroup(rank=0) + return + + dist_group = _build_tp_group_impl( + tp, + context.rank, + dist_config.world_size, + timeout=timeout, + cpu_backend=cpu_backend, + ccl_backend=ccl_backend, + attn_tp=tp, + tp_mode=TPMode.DEFAULT, + ) + context.attn_tp_group = dist_group + + +def _build_mlp_tp_group(context: 'DistContext', + timeout: timedelta, + cpu_backend: str = 'gloo', + ccl_backend: str = 'nccl'): + """Build mlp tp group.""" + dist_config = context.dist_config + tp = dist_config.mlp_tp + # skip if tp == 1 + if tp == 1: + context.mlp_tp_group = DistGroup(rank=0) + return + + # reuse attn tp group + if tp == dist_config.attn_tp: + context.mlp_tp_group = context.attn_tp_group + return + + dist_group = _build_tp_group_impl( + tp, + context.rank, + dist_config.world_size, + timeout=timeout, + cpu_backend=cpu_backend, + ccl_backend=ccl_backend, + attn_tp=dist_config.attn_tp, + tp_mode=dist_config.mlp_tp_mode, + ) + context.mlp_tp_group = dist_group + + +def _build_moe_tp_group(context: 'DistContext', + timeout: timedelta, + cpu_backend: str = 'gloo', + ccl_backend: str = 'nccl'): + """Build moe tp group.""" + dist_config = context.dist_config + tp = dist_config.moe_tp + # skip if tp == 1 + if tp == 1: + context.moe_tp_group = DistGroup(rank=0) + return + + # reuse attn tp group + if tp == dist_config.attn_tp: + context.moe_tp_group = context.attn_tp_group + return + + # reuse mlp tp group + if tp == dist_config.mlp_tp: + context.moe_tp_group = context.mlp_tp_group + return + + dist_group = _build_tp_group_impl( + tp, + context.rank, + dist_config.world_size, + timeout=timeout, + cpu_backend=cpu_backend, + ccl_backend=ccl_backend, + attn_tp=dist_config.attn_tp, + tp_mode=dist_config.moe_tp_mode, + ) + context.moe_tp_group = dist_group + + +def _build_tp_group(context: 'DistContext', timeout: timedelta, cpu_backend: str = 'gloo', ccl_backend: str = 'nccl'): + """Build tp group.""" + _build_attn_tp_group(context, timeout, cpu_backend, ccl_backend) + _build_mlp_tp_group(context, timeout, cpu_backend, ccl_backend) + _build_moe_tp_group(context, timeout, cpu_backend, ccl_backend) + context.tp_group = context.attn_tp_group @dataclass class DistContext: rank: int = 0 - world_size: int = 1 - tp: int = 1 - dp: int = 1 - ep: int = 1 - tp_rank: int = 0 dp_rank: int = 0 ep_rank: int = 0 - 
world_cpu_group: dist.ProcessGroup = None - tp_cpu_group: dist.ProcessGroup = None - tp_gpu_group: dist.ProcessGroup = None - tp_gpu_groups: List[dist.ProcessGroup] = None - dp_cpu_group: dist.ProcessGroup = None - dp_gpu_group: dist.ProcessGroup = None + + tp_group: DistGroup = None + attn_tp_group: DistGroup = None + mlp_tp_group: DistGroup = None + moe_tp_group: DistGroup = None + ep_gpu_group: dist.ProcessGroup = None ep_gpu_groups: List[dist.ProcessGroup] = None dist_config: DistConfig = None + @classmethod + def _build_ep_group(cls, context: 'DistContext', timeout: timedelta, ccl_backend: str = 'nccl'): + """Build ep group.""" + dist_config = context.dist_config + ep = dist_config.ep + if ep <= 1: + return + + dp_rank = context.dp_rank + world_size = dist_config.world_size + ep_rank = context.rank % ep + ep_group_id = dp_rank // ep + ranks = range(world_size) + ep_gpu_groups = [] + for start in range(0, world_size, ep): + ep_ranks = ranks[start:start + ep] + group = dist.new_group(ranks=ep_ranks, timeout=timeout, backend=ccl_backend) + ep_gpu_groups.append(group) + ep_gpu_group = ep_gpu_groups[ep_group_id] + + context.ep_rank = ep_rank + context.ep_gpu_group = ep_gpu_group + context.ep_gpu_groups = ep_gpu_groups + @classmethod def build(cls, rank: int = 0, dist_config: DistConfig = None, ccl_backend: str = 'nccl'): """Build dist context.""" - from datetime import timedelta timeout = timedelta(days=35600) cpu_backend = 'gloo' if dist_config is None: dist_config = DistConfig() - tp = dist_config.tp - dp = dist_config.dp - ep = dist_config.ep - world_size = dist_config.world_size - dp_rank = dist_config.dp_rank + dp_rank = dist_config.dp_rank + world_size = dist_config.world_size + context = DistContext(rank=rank, + dp_rank=dp_rank, + dist_config=dist_config, + attn_tp_group=DistGroup(rank=0), + mlp_tp_group=DistGroup(rank=0), + moe_tp_group=DistGroup(rank=0), + tp_group=DistGroup(rank=0)) if world_size == 1: - return DistContext(dist_config=dist_config) + return context assert dist.is_initialized() - # world(assume world group is gloo) - world_cpu_group = dist.GroupMember.WORLD - - tp_rank = rank % tp # tp - tp_gpu_group = None - tp_gpu_groups = None - tp_cpu_group = None - tp_group_id = dp_rank // tp - if tp > 1: - # all tp groups should be created in all procs - ranks = range(world_size) - tp_gpu_groups = [] - tp_cpu_groups = [] - for start in range(0, world_size, tp): - tp_ranks = ranks[start:start + tp] - group = dist.new_group(ranks=tp_ranks, timeout=timeout, backend=ccl_backend) - tp_gpu_groups.append(group) - cpu_group = dist.new_group(ranks=tp_ranks, timeout=timeout, backend=cpu_backend) - tp_cpu_groups.append(cpu_group) - tp_gpu_group = tp_gpu_groups[tp_group_id] - tp_cpu_group = tp_cpu_groups[tp_group_id] - - ep_rank = rank % ep - ep_gpu_group = None - ep_gpu_groups = None - ep_group_id = dp_rank // ep - if ep > 1: - ranks = range(world_size) - ep_gpu_groups = [] - for start in range(0, world_size, ep): - ep_ranks = ranks[start:start + ep] - group = dist.new_group(ranks=ep_ranks, timeout=timeout, backend=ccl_backend) - ep_gpu_groups.append(group) - ep_gpu_group = ep_gpu_groups[ep_group_id] - - dp_cpu_group = None - if dp > 1: - dp_cpu_group = dist.new_group(ranks=range(dp), timeout=timeout, backend=cpu_backend) - - context = DistContext( - rank=rank, - world_size=world_size, - tp=tp, - dp=dp, - ep=ep, - tp_rank=tp_rank, - dp_rank=dp_rank, - ep_rank=ep_rank, - world_cpu_group=world_cpu_group, - tp_cpu_group=tp_cpu_group, - tp_gpu_group=tp_gpu_group, - 
tp_gpu_groups=tp_gpu_groups, - dp_cpu_group=dp_cpu_group, - dp_gpu_group=None, - ep_gpu_group=ep_gpu_group, - ep_gpu_groups=ep_gpu_groups, - dist_config=dist_config, - ) + _build_tp_group(context, timeout, cpu_backend=cpu_backend, ccl_backend=ccl_backend) + + # ep + cls._build_ep_group(context, timeout, ccl_backend=ccl_backend) + return context def close(self): """Close groups.""" if not dist.is_initialized(): return - if self.tp_gpu_groups is not None: - for group in self.tp_gpu_groups: - dist.destroy_process_group(group) + if self.attn_tp_group is not None: + self.attn_tp_group.close() + if self.mlp_tp_group is not None: + self.mlp_tp_group.close() + if self.moe_tp_group is not None: + self.moe_tp_group.close() if self.ep_gpu_groups is not None: for group in self.ep_gpu_groups: dist.destroy_process_group(group) + self.ep_gpu_groups = None DefaultContext = DistContext.build() @@ -141,6 +288,10 @@ def set_context(self, context: DistContext): """Set current context.""" self.t_local.device_context = context + def current_config(self) -> DistConfig: + """Get current dist config.""" + return self.current_context().dist_config + @contextmanager def context(self, context: DistContext): """Context manager.""" @@ -164,25 +315,34 @@ def get_dist_manager(): def get_world_rank(): """Get distributed world size and rank.""" ctx = get_dist_manager().current_context() - world_size = ctx.world_size + world_size = ctx.dist_config.world_size rank = ctx.rank return world_size, rank -def get_tp_world_rank(): +def get_tp_world_rank(layer_type: Optional[str] = None): ctx = get_dist_manager().current_context() - return ctx.tp, ctx.tp_rank + if layer_type is None: + return ctx.dist_config.tp, ctx.tp_group.rank + elif layer_type == 'attn': + return ctx.dist_config.attn_tp, ctx.attn_tp_group.rank + elif layer_type == 'mlp': + return ctx.dist_config.mlp_tp, ctx.mlp_tp_group.rank + elif layer_type == 'moe': + return ctx.dist_config.moe_tp, ctx.moe_tp_group.rank + else: + raise RuntimeError(f'Unknown layer type: {layer_type}') def get_dp_world_rank(): ctx = get_dist_manager().current_context() - return ctx.dp, ctx.dp_rank + return ctx.dist_config.dp, ctx.dp_rank def get_ep_world_rank(): ctx = get_dist_manager().current_context() - return ctx.ep, ctx.ep_rank + return ctx.dist_config.ep, ctx.ep_rank def _check_group_device(device: str): @@ -193,48 +353,41 @@ def _check_group_device(device: str): def get_process_group(device: str = None): """Get process group.""" - ctx = get_dist_manager().current_context() - if device is None: - return dist.GroupMember.WORLD + return dist.GroupMember.WORLD - _check_group_device(device) - if device == 'cpu': - return ctx.world_cpu_group +def get_dist_group(layer_type: str = 'attn'): + """Get dist group.""" + ctx = get_dist_manager().current_context() + if layer_type == 'attn': + tp_group = ctx.attn_tp_group + elif layer_type == 'mlp': + tp_group = ctx.mlp_tp_group + elif layer_type == 'moe': + tp_group = ctx.moe_tp_group else: - raise RuntimeError('gpu world group is not supported.') + raise RuntimeError(f'Unknown layer type: {layer_type}') + return tp_group -def get_tp_group(device: str = 'gpu'): +def get_tp_group(device: str = 'gpu', layer_type: str = 'attn'): """Get tp group.""" - ctx = get_dist_manager().current_context() - _check_group_device(device) + tp_group = get_dist_group(layer_type) - if device == 'cpu': - return ctx.tp_cpu_group - else: - return ctx.tp_gpu_group - - -def get_dp_group(device: str = 'gpu'): - """Get dp group.""" - ctx = get_dist_manager().current_context() - - 
_check_group_device(device) + if tp_group is None: + return None if device == 'cpu': - return ctx.dp_cpu_group + return tp_group.cpu_group else: - return ctx.dp_gpu_group + return tp_group.gpu_group def get_group(group_type: str, device: str): """Get group.""" if group_type == 'tp': return get_tp_group(device) - elif group_type == 'dp': - return get_dp_group(device) elif group_type in ['world', 'all']: return get_process_group(device) else: @@ -278,3 +431,27 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group='tp', async_op=Fal if isinstance(group, str): group = get_group(group, 'gpu') return dist.reduce_scatter(output, input_list, op=op, group=group, async_op=async_op) + + +def gather_by_tp_sizes(x: torch.Tensor, + tp_sizes: List[int], + group: Optional[dist.ProcessGroup] = None, + async_op: bool = False): + """Gather input.""" + shape = (*x.shape[:-2], sum(tp_sizes), *x.shape[-1:]) + new_x = x.new_empty(shape) + split_new_x = list(new_x.split(tp_sizes, -2)) + handle = dist.all_gather(split_new_x, x, group=group, async_op=async_op) + if async_op: + return new_x, handle + return new_x + + +def reduce_scatter_by_tp_sizes(out: torch.Tensor, rank: int, tp_sizes: List[int], group: dist.ProcessGroup): + """Reduce scatter.""" + attn_tp = get_dist_manager().current_config().attn_tp + outs = list(out.split(tp_sizes, -2)) + outs = [item for item in outs for _ in range(attn_tp)] + out = outs[rank] + dist.reduce_scatter(out, outs, group=group) + return out diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 2179b5a99f..3133b3827d 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -133,12 +133,7 @@ def _build_backend_config(engine_config: PytorchEngineConfig): def _build_dist_config(engine_config: PytorchEngineConfig): """Build dist config.""" - dist_config = DistConfig(dp=engine_config.dp, - tp=engine_config.tp, - ep=engine_config.ep, - dp_rank=engine_config.dp_rank, - enable_microbatch=engine_config.enable_microbatch, - enable_eplb=engine_config.enable_eplb) + dist_config = DistConfig.from_engine_config(engine_config=engine_config) return dist_config @@ -602,7 +597,9 @@ def __update_max_new_tokens(msg): sampling_param = msg.sampling_param max_new_tokens = sampling_param.max_new_tokens num_all_tokens = msg.num_valid_ids - if max_new_tokens + num_all_tokens > max_session_len: + if self.engine_config.role == EngineRole.Prefill: + sampling_param.max_new_tokens = 1 + elif max_new_tokens + num_all_tokens > max_session_len: logger.warning( f'session[{msg.session_id}]: num tokens is larger than max session len {max_session_len}. 
' f'Update max_new_tokens={max_session_len - num_all_tokens}.') @@ -656,7 +653,7 @@ def model_config(self) -> ModelConfig: @property def gpu_count(self): - return self.tp * self.dp + return self.dist_config.world_size @property def torch_int_dtype(self): diff --git a/lmdeploy/pytorch/engine/executor/base.py b/lmdeploy/pytorch/engine/executor/base.py index 9e50843a80..abc06094a3 100644 --- a/lmdeploy/pytorch/engine/executor/base.py +++ b/lmdeploy/pytorch/engine/executor/base.py @@ -37,7 +37,6 @@ def __init__(self, self.dist_config = dist_config self.misc_config = misc_config self.dp = dist_config.dp - self.tp = dist_config.tp self.world_size = dist_config.world_size self.device_type = device_type @@ -163,7 +162,7 @@ def update_configs(self): logger.debug(f'minimal free gpu memory: {free_mem >> 20} mb') vocal_size = self.model_config.vocab_size - tp = self.dist_config.attn_config.tp + tp = self.dist_config.attn_tp cache_block_size = CacheEngine.get_cache_block_size(cache_config.block_size, model_config, tp, cache_config.quant_policy) runtime_mem, max_prefill_token_num = self._get_runtime_size(free_mem, cache_block_size, vocal_size) diff --git a/lmdeploy/pytorch/engine/executor/ray_executor.py b/lmdeploy/pytorch/engine/executor/ray_executor.py index 327d56a5ca..c142a62610 100644 --- a/lmdeploy/pytorch/engine/executor/ray_executor.py +++ b/lmdeploy/pytorch/engine/executor/ray_executor.py @@ -199,8 +199,9 @@ def warmup_dist(self): from lmdeploy.pytorch.distributed import all_reduce, get_dist_manager with get_dist_manager().context(self.dist_ctx): + group = self.dist_ctx.tp_group.gpu_group tmp = torch.empty((1, ), device='cuda') - all_reduce(tmp) + all_reduce(tmp, group=group) def pack_output(self, output: Dict): """Pack output.""" @@ -242,14 +243,11 @@ def __init__(self, adapters=adapters, device_type=device_type) - self.dp_rank = dist_config.dp_rank device_ctx = DeviceContext(device_type) with get_device_manager().context(device_ctx): logger.info('Init ray cluster.') - ray_world_size = self.world_size - if self.dp > 1: - ray_world_size = 1 - self.ray_ctx = RayContext(ray_world_size, dp=dist_config.dp, device_type=device_type) + attn_tp = dist_config.attn_tp + self.ray_ctx = RayContext(attn_tp, dp=dist_config.dp, device_type=device_type) placement_group = self.ray_ctx.get_placement_group() self.placement_group = placement_group @@ -285,13 +283,11 @@ def __init__(self, self._init_distributed_environment_by_device(device_type) logger.info('Init distributed process group.') - if self.dp == 1: - ray.get([ - worker.init_process_group.remote(rank, self.master_addr, self.master_port) - for rank, worker in enumerate(self.workers) - ]) - else: - ray.get(self.workers[0].init_process_group.remote(self.dp_rank, self.master_addr, self.master_port)) + rank_offset = dist_config.dp_rank * attn_tp + ray.get([ + worker.init_process_group.remote(rank + rank_offset, self.master_addr, self.master_port) + for rank, worker in enumerate(self.workers) + ]) if self.dist_config.world_size > 1: logger.info('Warming up distribute environment, this might take long time, please waiting...') @@ -510,7 +506,8 @@ def _init_workers_ray(self, placement_group: PlacementGroup, worker_kwargs: dict for bundle_id, bundle in enumerate(placement_group.bundle_specs): if bundle.get(device_str, 0) and self._valid_bundle_id(bundle_id): bundle_indices.append(bundle_id) - bundle_indices = bundle_indices[:self.world_size] + attn_tp = self.dist_config.attn_tp + bundle_indices = bundle_indices[:attn_tp] workers = list() for _, bundle_id in 
enumerate(bundle_indices): diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 43f42c93be..ec0aaeb4bc 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -3,7 +3,7 @@ import base64 import functools import time -from contextlib import asynccontextmanager, contextmanager +from contextlib import contextmanager from dataclasses import dataclass, fields from multiprocessing.reduction import ForkingPickler from os import getenv @@ -118,7 +118,7 @@ def __init__(self, dist_ctx: DistContext, stream: torch.Stream): from lmdeploy.pytorch import envs self.rank = dist_ctx.rank self.dp_rank = dist_ctx.dp_rank - self.dp = dist_ctx.dp + self.dp = dist_ctx.dist_config.dp self.stream = stream self.profiler = None if self.dp == 1: @@ -335,6 +335,7 @@ def __init__(self, device = 'cuda' self.backend_config = backend_config self.misc_config = misc_config + self.dist_config = dist_ctx.dist_config rank = dist_ctx.rank self.model_path = model_path @@ -342,19 +343,18 @@ def __init__(self, self.device = device self.rank = rank - tp_rank = dist_ctx.tp_rank - tp = dist_ctx.tp - world_size = dist_ctx.world_size + tp = self.dist_config.tp + world_size = self.dist_config.world_size self.tp = tp self.world_size = world_size - self.tp_rank = tp_rank + self.need_output = rank % self.dist_config.attn_tp == 0 self.patched_model = None self.cache_engine = None self.profiler: AgentProfiler = None # microbatch - self.enable_microbatch = self.dist_ctx.dist_config.enable_microbatch + self.enable_microbatch = self.dist_config.enable_microbatch self.enable_microbatch_prefill_batchsize_threshold = \ int(getenv('ENABLE_MICROBATCH_PREFILL_BATCHSIZE_THRESHOLD', 2)) self.enable_microbatch_prefill_token_threshold = \ @@ -392,11 +392,11 @@ def get_free_mem(self): def warmup(self): """warmup.""" # TODO: disable for now, do not remove the comments. 
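    # Illustrative rank layout for the per-DP-rank process groups used in this
    # file and in ray_executor above (hypothetical dp=2, tp=4, hence attn_tp=2,
    # world_size=4):
    #   dp_rank 0 launches workers with global ranks 0, 1 (rank_offset=0)
    #   dp_rank 1 launches workers with global ranks 2, 3 (rank_offset=2)
    #   attn_tp groups: {0, 1} and {2, 3}; mlp/moe tp group: {0, 1, 2, 3}
    #   need_output is True on ranks 0 and 2 (rank % attn_tp == 0), i.e. the
    #   first worker of each DP rank performs sampling.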
- with self.all_context(): + with self.all_context(), torch.cuda.stream(self.stream): max_batches = self.cache_config.max_batches num_tokens = max_batches - dist_ctx = get_dist_manager().current_context() - dp = dist_ctx.dp + dp = self.dist_config.dp + # warmup prefill inputs = self.inputs_strategy.make_dummy(max_batches, is_decoding=False, @@ -404,7 +404,10 @@ def warmup(self): vocab_size=self.model_config.vocab_size) if dp > 1: inputs.build_dp_meta() + logger.debug('Warmup prefill start.') self._forward_impl(inputs) + torch.cuda.synchronize() + logger.debug('Warmup prefill done.') # warmup decoding(with cuda graph) capture_batch_sizes = self.patched_model.get_capture_batch_sizes() @@ -416,7 +419,10 @@ def warmup(self): vocab_size=self.model_config.vocab_size) if dp > 1: inputs.build_dp_meta() + logger.debug(f'Warmup decoding num_tokens={num_tokens} start.') self._forward_impl(inputs) + torch.cuda.synchronize() + logger.debug(f'Warmup decoding num_tokens={num_tokens} done.') def _slice_outs(self, inputs: torch.Tensor, seq_length: torch.LongTensor): """Slice outputs.""" @@ -482,8 +488,8 @@ def get_output(self): async def __long_context_single_forward(new_inputs, max_seqlen: int): """One large sequence.""" - dist_ctx = get_dist_manager().current_context() - dp = dist_ctx.dp + dist_config = get_dist_manager().current_config() + dp = dist_config.dp model_metas = new_inputs[0].model_metas output_gather = _OutputGather(max_seqlen) for inp in new_inputs: @@ -576,6 +582,58 @@ def _broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: Extr with self.agent_strategy.broadcast_next_token(next_token_ids, extra_inputs, dist_ctx) as handle: yield handle + @record_function('prepare_dp') + async def _prepare_dp(self, inputs: ModelInputs, sync_long_context: bool, is_dummy: bool): + """Prepare dp.""" + world_size = self.dist_config.world_size + is_decoding = inputs.is_decoding + + # gather dp forward metadata + batch_size = inputs.seq_length.numel() + dp_forward_meta = [int(is_decoding), int(is_dummy), batch_size, int(sync_long_context)] + # check enable_microbatch + if self.enable_microbatch: + tokens_num = inputs.input_ids.numel() + if is_decoding: + enable_microbatch = batch_size >= \ + self.enable_microbatch_decode_batchsize_threshold + else: + enable_microbatch = batch_size >= \ + self.enable_microbatch_prefill_batchsize_threshold and \ + tokens_num >= self.enable_microbatch_prefill_token_threshold + dp_forward_meta.append(int(enable_microbatch)) + gathered_meta = DistGatherScalar(dp_forward_meta, world_size, device='cuda') + gathered_meta = (await gathered_meta.async_wait()).cpu() + + # check is_decoding + all_is_decoding = gathered_meta[:, 0] + assert all_is_decoding.sum().item() in [0, world_size] + + # check if all inputs are dummy inputs + is_all_dummy = gathered_meta[:, 1].all() + if is_all_dummy: + return inputs, sync_long_context, is_all_dummy + + if is_decoding: + all_batch_sizes = gathered_meta[:, 2] + padding_batch_size = all_batch_sizes.max().item() + meta = self.patched_model.get_meta() + meta.padding_batch_size = padding_batch_size + logger.debug(f'padding_batch_size={padding_batch_size}') + else: + all_sync_flags = gathered_meta[:, 3].bool() + sync_long_context = all_sync_flags.any() + logger.debug(f'sync_long_context={sync_long_context}') + + # update if enable_microbatch + if self.enable_microbatch and gathered_meta[:, 4].all(): + inputs.enable_microbatch = True + + # update dp meta + inputs.build_dp_meta() + inputs = self.patched_model.update_inputs(inputs) + return 
inputs, sync_long_context, is_all_dummy + async def _async_step_background( self, inputs: ModelInputs, @@ -590,7 +648,6 @@ async def _async_step_background( extra_inputs: ExtraInputs = None, ): """Asyc forward task.""" - dist_ctx = get_dist_manager().current_context() @record_function('update_inputs_for_next_step') def __update_inputs(next_token_ids, model_metas, extra_inputs): @@ -603,88 +660,29 @@ def __update_inputs(next_token_ids, model_metas, extra_inputs): extra_inputs=extra_inputs, ) - @asynccontextmanager - async def __prepare_dp(): - """Prepare dp.""" - if dp == 1: - yield - return - - nonlocal inputs, sync_long_context, is_all_dummy - - # gather dp forward metadata - batch_size = inputs.seq_length.numel() - dp_forward_meta = [int(is_decoding), int(is_dummy), batch_size, int(sync_long_context)] - # check enable_microbatch - if self.enable_microbatch: - tokens_num = inputs.input_ids.numel() - if is_decoding: - enable_microbatch = batch_size >= \ - self.enable_microbatch_decode_batchsize_threshold - else: - enable_microbatch = batch_size >= \ - self.enable_microbatch_prefill_batchsize_threshold and \ - tokens_num >= self.enable_microbatch_prefill_token_threshold - dp_forward_meta.append(int(enable_microbatch)) - gathered_meta = DistGatherScalar(dp_forward_meta, dp, device='cuda') - - yield - - gathered_meta = (await gathered_meta.async_wait()).cpu() - - # check is_decoding - all_is_decoding = gathered_meta[:, 0] - assert all_is_decoding.sum().item() in [0, dp] - - # check if all inputs are dummy inputs - is_all_dummy = gathered_meta[:, 1].all() - if is_all_dummy: - return - - if is_decoding: - all_batch_sizes = gathered_meta[:, 2] - padding_batch_size = all_batch_sizes.max().item() - meta = self.patched_model.get_meta() - meta.padding_batch_size = padding_batch_size - logger.debug(f'padding_batch_size={padding_batch_size}') - else: - all_sync_flags = gathered_meta[:, 3].bool() - sync_long_context = all_sync_flags.any() - logger.debug(f'sync_long_context={sync_long_context}') - - # update if enable_microbatch - if self.enable_microbatch and gathered_meta[:, 4].all(): - inputs.enable_microbatch = True - - # update dp meta - inputs.build_dp_meta() - inputs = self.patched_model.update_inputs(inputs) - # dist tools dist_ctx = get_dist_manager().current_context() - rank = dist_ctx.rank - tp = dist_ctx.tp - dp = dist_ctx.dp + dist_config = dist_ctx.dist_config + rank = self.rank + tp = dist_config.attn_tp + dp = dist_config.dp sync_long_context = False if dp == 1 else sync_long_context is_decoding = inputs.is_decoding + # is_all_dummy would be updated in __prepare_dp + if dp > 1: + inputs, sync_long_context, is_all_dummy = await self._prepare_dp(inputs, sync_long_context, is_dummy) + + # skip dummy forward. + if is_all_dummy: + logger.debug(f' rank[{rank}]: all inputs are dummy, skip forward.') + return + logger.debug(f' rank[{rank}]: ' f'batch_size={inputs.seq_length.size(0)} ' f'num_tokens={inputs.input_ids.size(-1)} ' f'is_decoding={inputs.is_decoding}') - # is_all_dummy would be updated in __prepare_dp - is_all_dummy = False - async with __prepare_dp(): - pass - - need_output = dp > 1 or rank % tp == 0 - - # skip dummy forward. 
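    # Shape of the metadata gathered in _prepare_dp above (illustrative,
    # world_size=2): each rank contributes [is_decoding, is_dummy, batch_size,
    # sync_long_context(, enable_microbatch)], so gathered_meta could be
    # [[1, 0, 8, 0], [1, 1, 1, 0]] -> decoding on both ranks, not all dummy,
    # padding_batch_size=max(8, 1)=8; the sync_long_context column is only
    # consulted for prefill, and the dummy rank still joins the forward so the
    # collectives stay matched.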
- if is_all_dummy: - logger.debug(f' rank[{rank}]: all inputs are dummy, skip forward.') - return - cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map) for idx in range(loop_count): # inference @@ -694,9 +692,7 @@ async def __prepare_dp(): return_logits=return_logits, sync_long_context=sync_long_context, ) - logits = output['logits'] - logits = logits[0] # [bs, seq, prob] -> [seq, prob] - seq_length = inputs.seq_length + logits = output['logits'][0] # [bs, seq, prob] -> [seq, prob] seq_length = output.get('seq_length', inputs.seq_length) last_logits = self._slice_outs(logits, seq_length) # [bs, 1, prob] -> [bs, prob] extra_inputs = self.agent_strategy.slice_extra_inputs(extra_inputs, seq_length) @@ -706,10 +702,10 @@ async def __prepare_dp(): if is_dummy: continue - need_broadcast_next = (dp == 1 and tp > 1 and idx < loop_count - 1) + need_broadcast_next = (tp > 1 and idx < loop_count - 1) # sampling and stopping - if need_output: + if self.need_output: logger.debug(f' rank[{rank}]: Sampling [{idx}].') # sampling next_token_ids, logprobs = await self.async_sampling_logits(last_logits, sampling_inputs, inputs) @@ -755,8 +751,8 @@ async def __prepare_dp(): async def _async_loop_background(self, forward_event: asyncio.Event = None): """Async loop background.""" with self.all_context(), torch.cuda.stream(self.stream), torch.inference_mode(): - dist_ctx = get_dist_manager().current_context() - dp = dist_ctx.dp + dist_config = get_dist_manager().current_config() + dp = dist_config.dp # for dp if dp > 1: @@ -842,7 +838,7 @@ def start(self, forward_event: asyncio.Event = None): def stop(self): """Stop task.""" - if self.dist_ctx.dp > 1: + if self.dist_config.dp > 1: return if self.profiler is not None: @@ -858,7 +854,7 @@ def stop(self): async def stop_async(self): """Stop task.""" - if self.dist_ctx.dp > 1: + if self.dist_config.dp > 1: return if self.profiler is not None: @@ -948,14 +944,14 @@ def build_graph_runner(self): def build_cache_engine(self): """Build cache engine.""" with self.all_context(): - dist_ctx = self.dist_ctx - attn_dist_cfg = dist_ctx.dist_config.attn_config - tp = attn_dist_cfg.tp + dist_ctx = get_dist_manager().current_context() + dist_cfg = self.dist_config + tp = dist_cfg.attn_tp self.cache_engine = CacheEngine(self.cache_config, self.model_config, rank=self.rank, - tp_rank=self.tp_rank, + tp_rank=dist_ctx.attn_tp_group.rank, world_size=tp, cache_stream=self.cache_stream) @@ -1009,7 +1005,7 @@ def _construct(item): with self.all_context(): serialized_data = request.serialized_named_tensors if isinstance(serialized_data, list): - serialized_data = serialized_data[self.dist_ctx.tp_rank] + serialized_data = serialized_data[self.dist_ctx.tp_group.rank] weights = ForkingPickler.loads(base64.b64decode(serialized_data)) weights = [(k, _construct(v)) for k, v in weights] self.patched_model.get_model().load_weights(weights) @@ -1092,6 +1088,7 @@ def __init__(self, model_agent: BaseModelAgent): self._next_inputs = None self._is_decoding = False self._ready_event = torch.cuda.Event() + self._attn_tp_cpu_group = self.dist_ctx.attn_tp_group.cpu_group def _make_dummy_forward_inputs(self): """Make dummy forward inputs.""" @@ -1119,6 +1116,70 @@ def _update_is_decoding(self, forward_inputs): if self.cache_config.role != EngineRole.Prefill: self._is_decoding = not self._is_decoding + async def _broadcast_has_inputs(self, has_inputs: bool = False): + """Broadcast has inputs.""" + attn_tp_group = self.dist_ctx.attn_tp_group + attn_tp = 
self.dist_ctx.dist_config.attn_tp + if attn_tp == 1: + return has_inputs + + group = attn_tp_group.cpu_group + rank = dist.get_global_rank(group, 0) + has_inputs = torch.tensor((has_inputs, )) + handle = dist.broadcast(has_inputs, src=rank, group=group, async_op=True) + future = handle.get_future() + while not future.done(): + await asyncio.sleep(0) + return has_inputs.item() + + async def _get_inputs_rank0(self): + """Try get inputs rank0.""" + try: + forward_inputs = await asyncio.wait_for(self._in_que.get(), timeout=0.02) + except asyncio.TimeoutError: + forward_inputs = None + + has_inputs = forward_inputs is not None + await self._broadcast_has_inputs(has_inputs) + return forward_inputs + + async def _get_inputs_rankn(self): + """Try get inputs rankn.""" + # broadcast + has_inputs = await self._broadcast_has_inputs() + + # try get inputs + if has_inputs: + forward_inputs = await self._in_que.get() + else: + forward_inputs = None + return forward_inputs + + async def _try_get_inputs(self): + """Try get inputs.""" + + attn_tp_group = self.dist_ctx.attn_tp_group + tp_rank = attn_tp_group.rank + + # initialize output + forward_inputs = None + need_dummy = True + + # get inputs from in_que. Rank 1 will not gather if rank 0 does not read inputs. + if tp_rank == 0: + forward_inputs = await self._get_inputs_rank0() + else: + forward_inputs = await self._get_inputs_rankn() + + if forward_inputs is not None: + model_inputs = forward_inputs['inputs'] + if model_inputs.is_decoding != self._is_decoding: + self._next_inputs = forward_inputs + else: + need_dummy = False + + return forward_inputs, need_dummy + async def get(self): """get.""" if self._next_inputs is not None: @@ -1132,16 +1193,7 @@ async def get(self): await asyncio.sleep(0.001) # try get inputs - need_dummy = True - try: - forward_inputs = await asyncio.wait_for(self._in_que.get(), timeout=0.02) - model_inputs = forward_inputs['inputs'] - if model_inputs.is_decoding != self._is_decoding: - self._next_inputs = forward_inputs - else: - need_dummy = False - except asyncio.TimeoutError: - pass + forward_inputs, need_dummy = await self._try_get_inputs() # make dummy inputs if need_dummy: diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index a377c9d4d6..6b9ba84dc6 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -19,22 +19,41 @@ @dataclass class DPMeta: tp_sizes: List[int] = None - ep_sizes: List[int] = None + moe_tp_sizes: List[int] = None + + @staticmethod + def _gather_tp_sizes(tp: int, seqlen: int, dist_ctx: dist.DistContext, layer_type: str): + """Gather tp size.""" + attn_tp = dist_ctx.dist_config.attn_tp + if tp > 1 and tp != attn_tp: + dist_group = dist.get_dist_group(layer_type=layer_type) + gather_group = dist_group.gpu_gather_group + tp_sizes = [None for _ in range(gather_group.size())] + dist.all_gather_object(tp_sizes, seqlen, group=gather_group) + else: + tp_sizes = [seqlen] + return tp_sizes @classmethod def build(cls, seqlen: int): """Get dp meta.""" dist_ctx = dist.get_dist_manager().current_context() + dist_config = dist_ctx.dist_config + + mlp_tp = dist_config.mlp_tp + tp_sizes = cls._gather_tp_sizes(mlp_tp, seqlen, dist_ctx, layer_type='mlp') - tp = dist_ctx.tp - if tp > 1: - tp_sizes = [None for _ in range(tp)] - tp_group = dist.get_tp_group('gpu') - dist.all_gather_object(tp_sizes, seqlen, group=tp_group) + moe_tp = dist_config.moe_tp + if moe_tp == mlp_tp: + moe_tp_sizes = tp_sizes else: - tp_sizes = [seqlen] + moe_tp_sizes = 
cls._gather_tp_sizes(moe_tp, seqlen, dist_ctx, layer_type='moe') + + return DPMeta(tp_sizes=tp_sizes, moe_tp_sizes=moe_tp_sizes) - return DPMeta(tp_sizes=tp_sizes, ) + def sync_tp_size(self, tp_size: int): + self.tp_sizes = [tp_size] * len(self.tp_sizes) + self.moe_tp_sizes = [tp_size] * len(self.moe_tp_sizes) @dataclass diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py index a10e5da520..c4bc09e876 100644 --- a/lmdeploy/pytorch/models/deepseek_v2.py +++ b/lmdeploy/pytorch/models/deepseek_v2.py @@ -341,16 +341,9 @@ def __init__(self, batch: int, in_features: int, out_features: int, dtype: torch self.dtype = dtype self.device = device - def _get_tp_world_rank(self): - """Get tp world rank.""" - dist_ctx = get_dist_manager().current_context() - if dist_ctx.dp == 1: - return get_tp_world_rank() - return 1, 0 - def _update_batch(self, batch: int): """Update out features.""" - world_size, _ = self._get_tp_world_rank() + world_size, _ = get_tp_world_rank('attn') batch = batch // world_size return batch @@ -360,7 +353,7 @@ def create_weight(self, batch: int, in_features: int, out_features: int, dtype: def weight_loader(self, param: nn.Parameter, weight: torch.Tensor): """Weight loader.""" - world_size, rank = self._get_tp_world_rank() + world_size, rank = get_tp_world_rank('attn') weight = weight.chunk(world_size, 0)[rank] param.data.copy_(weight) @@ -521,11 +514,11 @@ def forward( attn_metadata: Any = None, ): """Rewrite of LlamaAttention.forward.""" - dist_ctx = get_dist_manager().current_context() - if dist_ctx.dp > 1: + dist_config = get_dist_manager().current_config() + if dist_config.dp > 1: num_heads = self.num_heads else: - world_size = dist_ctx.world_size + world_size = dist_config.world_size num_heads = self.num_heads // world_size nope_size = self.kv_lora_rank q_len = hidden_states.size(1) @@ -678,9 +671,10 @@ def __init__(self, config: Any, layer_idx, dtype: torch.dtype = None, device: to self.topk_group = config.topk_group dist_ctx = get_dist_manager().current_context() - dp = dist_ctx.dp - world_size = dist_ctx.world_size - moe_all_reduce = dp > 1 and dist_ctx.tp > 1 + dist_config = dist_ctx.dist_config + dp = dist_config.dp + world_size = dist_config.world_size + moe_all_reduce = dp > 1 and dist_config.tp > 1 if get_dist_manager().current_context().dist_config.enable_eplb: eplb_dispatch_info = EPLBManager.get_dispatch_info( ep_rank=dist_ctx.ep_rank, @@ -753,8 +747,8 @@ def __init__(self, super().__init__() quantization_config = getattr(config, 'quantization_config', None) if is_shared_expert: - dist_ctx = get_dist_manager().current_context() - dp = dist_ctx.dp + dist_config = get_dist_manager().current_config() + dp = dist_config.dp if dp == 1: # split weight, do all reduce in moe is_tp = True diff --git a/lmdeploy/pytorch/models/llama4.py b/lmdeploy/pytorch/models/llama4.py index 4a84f35d85..28389f970d 100644 --- a/lmdeploy/pytorch/models/llama4.py +++ b/lmdeploy/pytorch/models/llama4.py @@ -205,8 +205,8 @@ def __init__(self, config: Llama4TextConfig, dtype: torch.dtype = None, device: ) self.shared_expert = Llama4TextMLP(config, dtype=dtype, device=device, is_tp=True, all_reduce=False) - dist_ctx = dist.get_dist_manager().current_context() - self.tp = dist_ctx.tp + dist_config = dist.get_dist_manager().current_config() + self.tp = dist_config.tp def forward(self, hidden_states: torch.Tensor): """forward.""" diff --git a/lmdeploy/pytorch/models/qwen3_moe.py b/lmdeploy/pytorch/models/qwen3_moe.py index 464953f264..7818787f3e 100644 --- 
a/lmdeploy/pytorch/models/qwen3_moe.py +++ b/lmdeploy/pytorch/models/qwen3_moe.py @@ -6,7 +6,7 @@ from torch import nn from transformers.configuration_utils import PretrainedConfig -from lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank, get_tp_world_rank +from lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config from lmdeploy.pytorch.nn.eplb import EPLBManager @@ -198,8 +198,6 @@ def __init__(self, self.softmax_topk = SoftmaxTopK(self.top_k) - world_size, _ = get_tp_world_rank() - _all_reduce = world_size > 1 if get_dist_manager().current_context().dist_config.enable_eplb: dist_ctx = get_dist_manager().current_context() self.eplb_dispatch_info = EPLBManager.get_dispatch_info( @@ -216,7 +214,7 @@ def __init__(self, dtype=dtype, device=device, quant_config=quantization_config, - all_reduce=_all_reduce, + all_reduce=True, layer_idx=layer_idx, ) diff --git a/lmdeploy/pytorch/models/sdar_moe.py b/lmdeploy/pytorch/models/sdar_moe.py index 522d2aed95..190411b898 100644 --- a/lmdeploy/pytorch/models/sdar_moe.py +++ b/lmdeploy/pytorch/models/sdar_moe.py @@ -6,7 +6,6 @@ from torch import nn from transformers.configuration_utils import PretrainedConfig -from lmdeploy.pytorch.distributed import get_tp_world_rank from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj, @@ -180,8 +179,6 @@ def __init__(self, self.softmax_topk = SoftmaxTopK(self.top_k) - world_size, _ = get_tp_world_rank() - _all_reduce = world_size > 1 self.experts = build_fused_moe( self.hidden_dim, self.ffn_dim, @@ -191,7 +188,7 @@ def __init__(self, dtype=dtype, device=device, quant_config=quantization_config, - all_reduce=_all_reduce, + all_reduce=True, layer_idx=layer_idx, ) diff --git a/lmdeploy/pytorch/nn/attention.py b/lmdeploy/pytorch/nn/attention.py index 7a1654db4b..52768591ff 100644 --- a/lmdeploy/pytorch/nn/attention.py +++ b/lmdeploy/pytorch/nn/attention.py @@ -2,7 +2,7 @@ import torch from torch import nn -from lmdeploy.pytorch.distributed import get_dist_manager, get_tp_world_rank +from lmdeploy.pytorch.distributed import get_tp_world_rank from ..backends import OpType, get_backend from ..backends.attention import AttentionMetadata @@ -11,11 +11,7 @@ def _update_num_heads(num_heads: int, num_kv_heads: int): """Update heads.""" - dist_ctx = get_dist_manager().current_context() - if dist_ctx.dp == 1: - world_size, rank = get_tp_world_rank() - else: - world_size, rank = 1, 0 + world_size, rank = get_tp_world_rank('attn') num_heads = get_distribute_size(num_heads, world_size, rank) num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank) return num_heads, num_kv_heads diff --git a/lmdeploy/pytorch/nn/linear/__init__.py b/lmdeploy/pytorch/nn/linear/__init__.py index e57eaed9e5..3738e921c4 100644 --- a/lmdeploy/pytorch/nn/linear/__init__.py +++ b/lmdeploy/pytorch/nn/linear/__init__.py @@ -4,7 +4,8 @@ import torch from torch import nn -from lmdeploy.pytorch.distributed import get_dist_manager, get_dp_world_rank, get_tp_world_rank +from lmdeploy.pytorch.config import TPMode +from lmdeploy.pytorch.distributed import get_dist_manager, get_tp_world_rank from 
.awq import AwqLinear, MergedAwqLinear, QKVAwqLinear from .blocked_fp8 import BlockedF8Linear, MergedBlockedF8Linear, QKVBlockedF8Linear @@ -13,51 +14,25 @@ from .w8a8 import MergedW8A8Linear, QKVW8A8Linear, W8A8Linear -def _is_dp_enabled(): - """Is dp.""" - return get_dp_world_rank()[0] > 1 - - -def _get_dp_gather(is_tp: bool): - """Get dp gather.""" - dp_gather = True - if not _is_dp_enabled(): - # disable if not dp - dp_gather = False - if not is_tp: - dp_gather = False - return dp_gather - - -def _get_dp_tp_meta(all_reduce: bool = True): - """Get tp meta.""" - dist_ctx = get_dist_manager().current_context() - dist_attn_cfg = dist_ctx.dist_config.attn_config - tp = dist_attn_cfg.tp - is_tp = tp > 1 - all_reduce = all_reduce if is_tp else False - return is_tp, all_reduce - - -def build_linear(in_features: int, - out_features: int, - bias: bool, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - colwise: bool = True, - is_tp: bool = False, - quant_config: Any = None, - all_reduce: bool = True, - tp_align_size: int = 1, - dp_gather: bool = False, - dp_scatter: bool = False) -> nn.Module: +def build_linear( + in_features: int, + out_features: int, + bias: bool, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + colwise: bool = True, + is_tp: bool = False, + quant_config: Any = None, + all_reduce: bool = True, + tp_align_size: int = 1, + dp_gather: bool = False, + layer_type: str = 'attn', +) -> nn.Module: """Build linear.""" - if is_tp: - is_tp = get_tp_world_rank()[0] > 1 - if not is_tp: - all_reduce = False - - if (dp_scatter or dp_gather) and quant_config is not None: + if layer_type is None: + layer_type = 'attn' + all_reduce = all_reduce if is_tp else False + if dp_gather and quant_config is not None: quant_method = quant_config['quant_method'] assert quant_method in ['fp8'], (f'Do not support dp_gather with quant_method={quant_method}') @@ -73,7 +48,7 @@ def build_linear(in_features: int, all_reduce=all_reduce, tp_align_size=tp_align_size, dp_gather=dp_gather, - dp_scatter=dp_scatter, + layer_type=layer_type, ) quant_method = quant_config['quant_method'] @@ -94,6 +69,7 @@ def build_linear(in_features: int, colwise=colwise, is_tp=is_tp, all_reduce=all_reduce, + layer_type=layer_type, ) if quant_method == 'smooth_quant': return W8A8Linear(in_features, @@ -104,7 +80,8 @@ def build_linear(in_features: int, colwise=colwise, is_tp=is_tp, all_reduce=all_reduce, - quant_dtype=quant_dtype) + quant_dtype=quant_dtype, + layer_type=layer_type) elif quant_method == 'fp8': fmt = quant_config.get('fmt', 'e4m3') if fmt == 'e4m3': @@ -124,7 +101,7 @@ def build_linear(in_features: int, is_tp=is_tp, all_reduce=all_reduce, dp_gather=dp_gather, - dp_scatter=dp_scatter, + layer_type=layer_type, ) else: raise RuntimeError(f'Unsupported quant method: {quant_method}') @@ -139,16 +116,20 @@ def build_colwise_linear(in_features: int, tp_align_size: int = 1, quant_config: Any = None, dp_disable_tp: bool = False, - dp_gather: bool = False) -> nn.Module: + dp_gather: bool = False, + check_dist: bool = True, + layer_type: str = 'attn') -> nn.Module: """Build columnwise parallel linear layer.""" - if dp_disable_tp and is_tp: - is_tp, _ = _get_dp_tp_meta() - elif is_tp: - is_tp = get_tp_world_rank()[0] > 1 + if check_dist: + dist_config = get_dist_manager().current_config() + tp, tp_mode = dist_config.get_tp_by_layer(layer_type) + + # check is_tp + is_tp = is_tp if tp > 1 else False + is_tp = False if (dp_disable_tp and dist_config.dp > 1) else is_tp - if 
dp_gather: - assert not dp_disable_tp - dp_gather = _get_dp_gather(is_tp) + # check dp_gather + dp_gather = dp_gather if is_tp and tp_mode == TPMode.DP_TP else False return build_linear(in_features=in_features, out_features=out_features, @@ -160,7 +141,8 @@ def build_colwise_linear(in_features: int, quant_config=quant_config, all_reduce=False, tp_align_size=tp_align_size, - dp_gather=dp_gather) + dp_gather=dp_gather, + layer_type=layer_type) def build_rowwise_linear(in_features: int, @@ -173,10 +155,14 @@ def build_rowwise_linear(in_features: int, quant_config: Any = None, all_reduce: bool = True, dp_disable_tp: bool = False, - dp_scatter: bool = False) -> nn.Module: + check_dist: bool = True, + layer_type: str = 'attn') -> nn.Module: """Build rowwise parallel linear layer.""" - if dp_disable_tp and is_tp: - is_tp, all_reduce = _get_dp_tp_meta(all_reduce) + if check_dist: + dist_config = get_dist_manager().current_config() + tp, _ = dist_config.get_tp_by_layer(layer_type) + is_tp = is_tp if tp > 1 else False + is_tp = False if (dp_disable_tp and dist_config.dp > 1) else is_tp return build_linear( in_features=in_features, out_features=out_features, @@ -188,7 +174,7 @@ def build_rowwise_linear(in_features: int, quant_config=quant_config, all_reduce=all_reduce, tp_align_size=tp_align_size, - dp_scatter=dp_scatter, + layer_type=layer_type, ) @@ -202,10 +188,12 @@ def build_merged_colwise_linear( is_tp: bool = True, out_names: List[Any] = None, dp_gather: bool = False, + check_dist: bool = True, + layer_type: str = 'attn', ): """Merge linear.""" - if is_tp: - is_tp = get_tp_world_rank()[0] > 1 + if check_dist and is_tp: + is_tp = get_tp_world_rank(layer_type)[0] > 1 if dp_gather and quant_config is not None: quant_method = quant_config['quant_method'] @@ -219,7 +207,8 @@ def build_merged_colwise_linear( device=device, is_tp=is_tp, out_names=out_names, - dp_gather=dp_gather) + dp_gather=dp_gather, + layer_type=layer_type) quant_method = quant_config['quant_method'] quant_dtype = torch.int8 @@ -237,6 +226,7 @@ def build_merged_colwise_linear( bias=bias, device=device, is_tp=is_tp, + layer_type=layer_type, ) if quant_method == 'smooth_quant': return MergedW8A8Linear(in_features=in_features, @@ -246,7 +236,8 @@ def build_merged_colwise_linear( device=device, is_tp=is_tp, out_names=out_names, - quant_dtype=quant_dtype) + quant_dtype=quant_dtype, + layer_type=layer_type) elif quant_method == 'fp8': fmt = quant_config.get('fmt', 'e4m3') if fmt == 'e4m3': @@ -265,6 +256,7 @@ def build_merged_colwise_linear( is_tp=is_tp, out_names=out_names, dp_gather=dp_gather, + layer_type=layer_type, ) else: raise RuntimeError(f'Unsupported quant method: {quant_method}') @@ -280,19 +272,10 @@ def build_qkv_proj(in_features: int, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, is_tp: bool = True, - num_replicate_kv_heads: int = 1, - dp_disable_tp: bool = True, - all_reduce: bool = False, - dp_gather: bool = False): + num_replicate_kv_heads: int = 1): """Build qkv proj.""" - if dp_disable_tp and is_tp: - is_tp, _ = _get_dp_tp_meta(all_reduce) - elif is_tp: - is_tp = get_tp_world_rank()[0] > 1 - - if dp_gather: - assert not dp_disable_tp - dp_gather = _get_dp_gather(is_tp) + dist_config = get_dist_manager().current_config() + is_tp = is_tp if dist_config.attn_tp > 1 else False if head_size_v is None: head_size_v = head_size @@ -358,7 +341,7 @@ def build_qkv_proj(in_features: int, dtype=dtype, device=device, is_tp=is_tp, - dp_gather=dp_gather, + dp_gather=False, 
num_replicate_kv_heads=num_replicate_kv_heads) else: raise RuntimeError(f'Unsupported quant method: {quant_method}') @@ -374,8 +357,8 @@ def build_o_proj(in_features: int, quant_config: Any = None, all_reduce: bool = True) -> nn.Module: """Build down linear.""" - if is_tp: - is_tp, all_reduce = _get_dp_tp_meta(all_reduce) + dist_config = get_dist_manager().current_config() + is_tp = is_tp if dist_config.attn_tp > 1 else False return build_rowwise_linear( in_features=in_features, @@ -387,6 +370,8 @@ def build_o_proj(in_features: int, tp_align_size=tp_align_size, quant_config=quant_config, all_reduce=all_reduce, + check_dist=False, + layer_type='attn', ) @@ -400,12 +385,10 @@ def build_gateup_linear(in_features: int, out_names: List[Any] = None, dp_gather: bool = True): """Build gate up linear.""" - if dp_gather: - if is_tp: - is_tp = get_tp_world_rank()[0] > 1 - dp_gather = _get_dp_gather(is_tp) - elif is_tp: - is_tp, _ = _get_dp_tp_meta() + dist_config = get_dist_manager().current_config() + tp, tp_mode = dist_config.get_tp_by_layer('mlp') + is_tp = is_tp if tp > 1 else False + dp_gather = dp_gather if is_tp and tp_mode == TPMode.DP_TP else False return build_merged_colwise_linear( in_features=in_features, @@ -417,6 +400,8 @@ def build_gateup_linear(in_features: int, is_tp=is_tp, out_names=out_names, dp_gather=dp_gather, + check_dist=False, + layer_type='mlp', ) @@ -428,19 +413,10 @@ def build_down_linear(in_features: int, is_tp: bool = False, tp_align_size: int = 1, quant_config: Any = None, - all_reduce: bool = True, - dp_scatter: bool = True) -> nn.Module: + all_reduce: bool = True) -> nn.Module: """Build down linear.""" - if dp_scatter: - if is_tp: - is_tp = get_tp_world_rank()[0] > 1 - if not _is_dp_enabled(): - # disable if not dp - dp_scatter = False - if not is_tp: - dp_scatter = False - elif is_tp: - is_tp, all_reduce = _get_dp_tp_meta(all_reduce) + dist_config = get_dist_manager().current_config() + is_tp = is_tp if dist_config.mlp_tp > 1 else False return build_rowwise_linear( in_features=in_features, @@ -452,5 +428,6 @@ def build_down_linear(in_features: int, tp_align_size=tp_align_size, quant_config=quant_config, all_reduce=all_reduce, - dp_scatter=dp_scatter, + check_dist=False, + layer_type='mlp', ) diff --git a/lmdeploy/pytorch/nn/linear/awq.py b/lmdeploy/pytorch/nn/linear/awq.py index ede2cea60c..5e24d93db7 100644 --- a/lmdeploy/pytorch/nn/linear/awq.py +++ b/lmdeploy/pytorch/nn/linear/awq.py @@ -8,7 +8,7 @@ from ..utils import chunk_aligned, get_distribute_size from .base import LinearBase -from .utils import QKVMixin, _get_tp_world_rank, check_qkv_split_layout +from .utils import QKVMixin, check_qkv_split_layout class AwqLinear(LinearBase): @@ -25,8 +25,14 @@ def __init__( colwise: bool = True, is_tp: bool = False, all_reduce: bool = True, + layer_type: str = 'attn', ): - super().__init__(dtype=torch.float16, device=device, colwise=colwise, is_tp=is_tp, all_reduce=all_reduce) + super().__init__(dtype=torch.float16, + device=device, + colwise=colwise, + is_tp=is_tp, + all_reduce=all_reduce, + layer_type=layer_type) if self.is_tp: in_features, out_features = self._get_io_features(in_features, out_features, w_bit, group_size, colwise) qweight, scales, qzeros, bias = self.create_weights(in_features, out_features, w_bit, group_size, bias, @@ -78,7 +84,7 @@ def register_all_parameters(self, def _get_io_features(self, in_features: int, out_features: int, w_bit: int, group_size: int, colwise: bool): """Get io features.""" align = max(32 // w_bit, group_size) - world_size, rank = 
_get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() if colwise: out_features = get_distribute_size(out_features, world_size, rank, align=align) else: @@ -128,7 +134,7 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor): if not self.is_tp: return default_weight_loader(param, loaded_weight) - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() if self.colwise: return self._weight_loader_tp_colwise(param, loaded_weight, rank, world_size) else: @@ -159,7 +165,7 @@ def update_weights(self): def _forward_default(self, x, all_reduce, tp_sizes): """Default forward implement.""" - return self.impl.forward(x, self.qweight, self.scales, self.qzeros, self.bias, all_reduce) + return self.impl.forward(x, self.qweight, self.scales, self.qzeros, self.bias, all_reduce, group=self.tp_group) class MergedAwqLinear(AwqLinear): @@ -173,10 +179,11 @@ def __init__(self, bias: bool, device: Optional[torch.device] = None, is_tp: bool = True, - out_names: Optional[List[int]] = None): + out_names: Optional[List[int]] = None, + layer_type: str = 'attn'): + self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type=layer_type) self.split_section_s = all_out_features - self.is_tp = is_tp elem_per_int = 32 // w_bit self.split_section_wz = [size // elem_per_int for size in all_out_features] @@ -187,7 +194,15 @@ def __init__(self, assert len(out_names) == len(self.all_out_features) self.out_names_map = dict((name, idx) for idx, name in enumerate(out_names)) out_features = sum(all_out_features) - super().__init__(in_features, out_features, w_bit, group_size, bias, device, colwise=True, is_tp=is_tp) + super().__init__(in_features, + out_features, + w_bit, + group_size, + bias, + device, + colwise=True, + is_tp=is_tp, + layer_type=layer_type) self.setup_loaders() def setup_loaders(self): @@ -212,7 +227,7 @@ def _get_io_features(self, in_features: int, out_features: int, w_bit: int, grou def _update_all_out_features(self, all_out_features: List[int], w_bit: int, group_size: int): """Update all out features.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() new_all_out_features = [] align = max(32 // w_bit, group_size) for out_feat in all_out_features: @@ -222,7 +237,7 @@ def _update_all_out_features(self, all_out_features: List[int], w_bit: int, grou def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any): """Weight loader.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() shard_idx = self.out_names_map[shard_id] if loaded_weight.dim() == 1: # bias @@ -268,13 +283,16 @@ def __init__(self, device: Optional[torch.device] = None, is_tp: bool = True, num_replicate_kv_heads: int = 1): + self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type='attn') QKVMixin.__init__(self, num_q_heads=num_q_heads, num_kv_heads=num_kv_heads, head_size=head_size, head_size_v=head_size_v, num_replicate_kv_heads=num_replicate_kv_heads, - is_tp=is_tp) + is_tp=is_tp, + tp=self.tp, + tp_rank=self.tp_rank) elem_per_int = 32 // w_bit self.qkv_split_section_s = self.qkv_split_section @@ -288,7 +306,8 @@ def __init__(self, bias=bias, device=device, is_tp=is_tp, - out_names=out_names) + out_names=out_names, + layer_type='attn') def _update_all_out_features(self, all_out_features: List[int], w_bit: int, group_size: int): """Update all out features.""" @@ -296,7 +315,7 @@ def 
_update_all_out_features(self, all_out_features: List[int], w_bit: int, grou def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any): """Weight loader.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() chunk_size, chunk_idx = world_size, rank shard_idx = self.out_names_map[shard_id] diff --git a/lmdeploy/pytorch/nn/linear/base.py b/lmdeploy/pytorch/nn/linear/base.py index e6711f271e..0229e4b67d 100644 --- a/lmdeploy/pytorch/nn/linear/base.py +++ b/lmdeploy/pytorch/nn/linear/base.py @@ -1,38 +1,102 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Optional +from typing import Callable, List, Optional import torch import torch.distributed as dist from torch import nn -from lmdeploy.pytorch.distributed import get_tp_world_rank +from lmdeploy.pytorch.config import TPMode +from lmdeploy.pytorch.distributed import (gather_by_tp_sizes, get_dist_group, get_dist_manager, get_tp_world_rank, + reduce_scatter_by_tp_sizes) from lmdeploy.pytorch.model_inputs import get_step_ctx_manager from .utils import update_tp_args -def _gather_input(x: torch.Tensor, tp_sizes: List[int]): - """Gather input.""" - shape0 = x.shape[:-2] - shape1 = x.shape[-1:] - shapes = [shape0 + (size, ) + shape1 for size in tp_sizes] - new_x = [x.new_empty(shape) for shape in shapes] - dist.all_gather(new_x, x) - x = torch.cat(new_x, dim=-2) - return x - - -def _reduce_scatter_input(out: torch.Tensor, tp_sizes: List[int]): - """Reduce scatter.""" - _, rank = get_tp_world_rank() - out = out.transpose(0, -2) - if not out.is_contiguous(): - out = out.contiguous() - outs = out.split(tp_sizes, 0) - out = outs[rank] - dist.reduce_scatter(out, outs) - out = out.transpose(0, -2) - return out +class LinearForwardDPTP: + + def __init__(self, gemm_func: Callable, max_tokens_per_round: int = 8192): + """Linear forward dp tp.""" + self.gemm_func = gemm_func + self.dist_ctx = get_dist_manager().current_context() + self.dist_config = self.dist_ctx.dist_config + self.tp = self.dist_config.mlp_tp + self.attn_tp = self.dist_config.attn_tp + + tp_group = self.dist_ctx.mlp_tp_group + self.rank = tp_group.rank + self.gather_rank = self.rank // self.attn_tp + self.gather_group = tp_group.gpu_gather_group + self.tp_group = tp_group.gpu_group + self.max_tokens_per_round = max_tokens_per_round * self.attn_tp // self.tp // 2 + + def all_gather(self, hidden_states: torch.Tensor, tp_sizes: List[int]): + """All gather.""" + hidden_states, handle = dist.gather_by_tp_sizes(hidden_states, tp_sizes, group=self.gather_group, async_op=True) + return hidden_states, handle + + def reduce_scatter(self, hidden_states: torch.Tensor, out_states: torch.Tensor, tp_sizes: List[int]): + """Reduce scatter.""" + hidden_states_list = list(hidden_states.split(tp_sizes, -2)) + hidden_states_list[self.gather_rank] = out_states + hidden_states_list = [item for item in hidden_states_list for _ in range(self.attn_tp)] + handle = dist.reduce_scatter(out_states, hidden_states_list, group=self.tp_group, async_op=True) + return out_states, handle + + def _gemm_and_reduce_scatter(self, hidden_states: torch.Tensor, output_states: torch.Tensor, tp_sizes: List[int], + handle: dist.Work): + """Gemm and reduce scatter.""" + handle.wait() + cur_out = self.gemm_func(hidden_states) + cur_out_states = cur_out.split(tp_sizes, dim=0)[self.gather_rank] + output_states.copy_(cur_out_states) + return self.reduce_scatter(cur_out, output_states, tp_sizes) + + def forward(self, hidden_states: 
torch.Tensor): + """forward.""" + + def __slice_tensor(tensor: torch.Tensor, slice_size: int): + """Slice tensor.""" + cur_tensor = tensor[:slice_size] + tensor = tensor[slice_size:] + return cur_tensor, tensor + + def __slice_and_gather(): + """Slice and gather.""" + nonlocal hidden_states, tp_sizes, output_states + cur_tp_sizes = tp_sizes.minimum(max_tokens_per_round) + tp_sizes -= cur_tp_sizes + cur_tp_sizes = cur_tp_sizes.tolist() + + slice_size = cur_tp_sizes[self.gather_rank] + cur_hidden_states, hidden_states = __slice_tensor(hidden_states, slice_size) + cur_output, output_states = __slice_tensor(output_states, slice_size) + + # all gather + cur_hidden_states, handle = self.all_gather(cur_hidden_states, cur_tp_sizes) + return dict(hidden_states=cur_hidden_states, output_states=cur_output, handle=handle, tp_sizes=cur_tp_sizes) + + step_ctx = get_step_ctx_manager().current_context() + tp_sizes = step_ctx.dp_meta.moe_tp_sizes + tp_sizes = torch.tensor(tp_sizes) + max_tokens_per_round = tp_sizes.new_tensor(self.max_tokens_per_round) + + output_states = torch.empty_like(hidden_states) + return_states = output_states + + # pre + cur_inputs = __slice_and_gather() + + # main loop + while tp_sizes.sum() > 0: + next_inputs = __slice_and_gather() + self._gemm_and_reduce_scatter(**cur_inputs) + cur_inputs = next_inputs + + # post + _, handle = self._gemm_and_reduce_scatter(**cur_inputs) + handle.wait() + return return_states class LinearBase(nn.Module): @@ -47,25 +111,64 @@ def __init__( all_reduce: bool = True, tp_align_size: int = 1, dp_gather: bool = False, - dp_scatter: bool = False, + layer_type: str = 'attn', ): super().__init__() - is_tp, all_reduce = update_tp_args(is_tp, all_reduce, colwise) + self.init_tp_args(is_tp, all_reduce, colwise, layer_type) self.colwise = colwise - self.is_tp = is_tp - self.all_reduce = all_reduce self.tp_align_size = tp_align_size self.dp_gather = dp_gather - self.dp_scatter = dp_scatter if device is None: device = torch.device('cpu') if dtype is None: dtype = torch.float16 self.device = device self.dtype = dtype + self.layer_type = layer_type self.lora_adapters = nn.ModuleDict() + def init_tp_args(self, is_tp: bool, all_reduce: bool, colwise: bool, layer_type: str): + if getattr(self, '_tp_args_initialized', False): + return + is_tp, all_reduce = update_tp_args(is_tp, all_reduce, colwise, layer_type=layer_type) + self.is_tp = is_tp + self.all_reduce = all_reduce + if is_tp: + dist_cfg = get_dist_manager().current_config() + _, rank = get_tp_world_rank(layer_type) + tp, tp_mode = dist_cfg.get_tp_by_layer(layer_type) + self.tp_rank = rank + self.tp = tp + self.tp_mode = tp_mode + dist_group = get_dist_group(layer_type=layer_type) + self.tp_group = dist_group.gpu_group + self.gather_group = dist_group.gpu_gather_group + else: + self.tp_rank = 0 + self.tp = 1 + self.tp_mode = TPMode.DEFAULT + self.tp_group = None + self.gather_group = None + + if self.tp > 1 and self.tp_mode == TPMode.DP_TP: + + def _gemm_func(self, x): + out = self._forward_default(x, False, None) + + for lora_adapter in self.lora_adapters.values(): + out = lora_adapter(x, out) + return out + + self.linear_dptp_forward = LinearForwardDPTP(_gemm_func) + + self._tp_args_initialized = True + + def get_tp_world_rank(self): + """Get tp world rank.""" + assert hasattr(self, 'tp') and hasattr(self, 'tp_rank'), 'Please run init_tp_args first.' 
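+        # Cached by init_tp_args: `tp` is the world size and `tp_rank` the rank
+        # resolved for this layer's layer_type ('attn', 'mlp' or 'moe').
+        # Weight loaders call this accessor instead of the module-level
+        # _get_tp_world_rank(is_tp) helper removed from nn/linear/utils.py.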
+ return self.tp, self.tp_rank + def update_weights(self): """Update weights.""" raise NotImplementedError('This method should be implemented in subclasses.') @@ -81,24 +184,35 @@ def _forward_lora(self, x, tp_sizes: List[int]): for lora_adapter in self.lora_adapters.values(): out = lora_adapter(x, out) if self.all_reduce: - if self.dp_scatter: - out = _reduce_scatter_input(out, tp_sizes) + if self.tp_mode == TPMode.DP_TP: + out = reduce_scatter_by_tp_sizes(out, self.tp_rank, tp_sizes, group=self.tp_group) else: - dist.all_reduce(out) + dist.all_reduce(out, group=self.tp_group) return out - def forward(self, x): - """Forward of linear layer.""" - tp_sizes = None - if self.dp_gather or self.dp_scatter: - step_ctx = get_step_ctx_manager().current_context() - dp_meta = step_ctx.dp_meta - tp_sizes = dp_meta.tp_sizes + def _forward_dp_tp(self, x): + """Forward dp_tp.""" + if self.dp_gather and self.all_reduce: + return self.linear_dptp_forward.forward(x) + + step_ctx = get_step_ctx_manager().current_context() + dp_meta = step_ctx.dp_meta + tp_sizes = dp_meta.tp_sizes if self.dp_gather: - x = _gather_input(x, tp_sizes) + x = gather_by_tp_sizes(x, tp_sizes, group=self.gather_group) if len(self.lora_adapters) == 0: return self._forward_default(x, self.all_reduce, tp_sizes) else: return self._forward_lora(x, tp_sizes) + + def forward(self, x): + """Forward of linear layer.""" + if self.tp > 1 and self.tp_mode == TPMode.DP_TP: + return self._forward_dp_tp(x) + + if len(self.lora_adapters) == 0: + return self._forward_default(x, self.all_reduce, None) + else: + return self._forward_lora(x) diff --git a/lmdeploy/pytorch/nn/linear/blocked_fp8.py b/lmdeploy/pytorch/nn/linear/blocked_fp8.py index 0638bf7650..af67358bc3 100644 --- a/lmdeploy/pytorch/nn/linear/blocked_fp8.py +++ b/lmdeploy/pytorch/nn/linear/blocked_fp8.py @@ -4,13 +4,13 @@ import torch from lmdeploy.pytorch.backends import OpType, get_backend -from lmdeploy.pytorch.distributed import get_tp_world_rank +from lmdeploy.pytorch.config import TPMode from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader from ..quant_utils import quant_blocked_fp8 from ..utils import div_up, get_distribute_size from .base import LinearBase -from .utils import QKVMixin, _get_tp_world_rank, check_qkv_split_layout +from .utils import QKVMixin, check_qkv_split_layout class BlockedF8Linear(LinearBase): @@ -28,7 +28,7 @@ def __init__( is_tp: bool = False, all_reduce: bool = True, dp_gather: bool = False, - dp_scatter: bool = False, + layer_type: str = 'attn', ): super().__init__(dtype=dtype, device=device, @@ -36,7 +36,7 @@ def __init__( is_tp=is_tp, all_reduce=all_reduce, dp_gather=dp_gather, - dp_scatter=dp_scatter) + layer_type=layer_type) self.block_size = 128 self.fp8_dtype = fp8_dtype if self.is_tp: @@ -76,7 +76,7 @@ def register_all_parameters(self, def _get_io_features(self, in_features: int, out_features: int, colwise: bool): """Get io features.""" - world_size, rank = get_tp_world_rank() + world_size, rank = self.get_tp_world_rank() if colwise: out_features = get_distribute_size(out_features, world_size, rank) else: @@ -107,7 +107,7 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor): if not self.is_tp: return default_weight_loader(param, loaded_weight) - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() if self.colwise: return self._weight_loader_tp_colwise(param, loaded_weight, rank, world_size) else: @@ -142,11 +142,18 @@ def update_weights(self): 
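The hunk below routes the DP_TP branch of `_forward_default` through the layer's process group plus the per-rank token counts, ending in a reduce-scatter sized by `scatter_size`. As a rough reference, here is a minimal sketch of what `reduce_scatter_by_tp_sizes` is assumed to do, reconstructed from the `_reduce_scatter_input` helpers removed in nn/linear/base.py and nn/moe.py; the 2-D [tokens, hidden] layout and this standalone body are assumptions, not the real lmdeploy.pytorch.distributed implementation:

    from typing import List, Optional

    import torch
    import torch.distributed as dist

    def reduce_scatter_by_tp_sizes_sketch(out: torch.Tensor,
                                          rank: int,
                                          tp_sizes: List[int],
                                          group: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
        # split the token dim into the uneven per-rank chunks recorded in dp_meta
        chunks = list(out.split(tp_sizes, dim=0))
        local = chunks[rank]
        # sum the matching chunks across ranks; each rank keeps only its own slice
        dist.reduce_scatter(local, chunks, group=group)
        return local

With equal `tp_sizes` this degenerates to a plain reduce-scatter; the uneven case is what dp > 1 with per-rank batch sizes requires.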
def _forward_default(self, x, all_reduce, tp_sizes): """Default forward implement.""" - if self.dp_scatter: - _, rank = get_tp_world_rank() - return self.impl.forward(x, self.weight, self.weight_scale_inv, self.bias, all_reduce, rank, tp_sizes) + if self.tp_mode == TPMode.DP_TP: + rank = self.tp_rank + return self.impl.forward(x, + self.weight, + self.weight_scale_inv, + self.bias, + all_reduce, + group=self.tp_group, + rank=rank, + scatter_size=tp_sizes) else: - return self.impl.forward(x, self.weight, self.weight_scale_inv, self.bias, all_reduce) + return self.impl.forward(x, self.weight, self.weight_scale_inv, self.bias, all_reduce, group=self.tp_group) class MergedBlockedF8Linear(BlockedF8Linear): @@ -162,12 +169,13 @@ def __init__(self, device: Optional[torch.device] = None, is_tp: bool = True, out_names: Optional[List[int]] = None, - dp_gather: bool = False): + dp_gather: bool = False, + layer_type: str = 'attn'): + self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type=layer_type) if replicate is None: replicate = tuple(False for _ in all_out_features) self.block_size = 128 self.split_section = all_out_features - self.is_tp = is_tp self.scale_split_section = [section // self.block_size for section in self.split_section] all_out_features = self._update_all_out_features(all_out_features, replicate) self.all_out_features = all_out_features @@ -185,7 +193,8 @@ def __init__(self, fp8_dtype=fp8_dtype, colwise=True, is_tp=is_tp, - dp_gather=dp_gather) + dp_gather=dp_gather, + layer_type=layer_type) self.setup_loaders() def setup_loaders(self): @@ -207,7 +216,7 @@ def _get_io_features(self, in_features: int, out_features: int, colwise: bool): def _update_all_out_features(self, all_out_features: List[int], replicate: Optional[List[bool]]): """Update all out features.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() new_all_out_features = [] for out_feat, rep in zip(all_out_features, replicate): if rep: @@ -218,7 +227,7 @@ def _update_all_out_features(self, all_out_features: List[int], replicate: Optio def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any): """Weight loader.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() shard_idx = self.out_names_map[shard_id] if loaded_weight.dim() == 2 and loaded_weight.dtype != self.fp8_dtype: loaded_weight = loaded_weight.to(torch.float32) @@ -267,13 +276,16 @@ def __init__(self, dp_gather: bool = False, num_replicate_kv_heads: int = 1): self.block_size = 128 + self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type='attn') QKVMixin.__init__(self, num_q_heads=num_q_heads, num_kv_heads=num_kv_heads, head_size=head_size, head_size_v=head_size_v, num_replicate_kv_heads=num_replicate_kv_heads, - is_tp=is_tp) + is_tp=is_tp, + tp=self.tp, + tp_rank=self.tp_rank) all_out_features = self.get_qkv_out_feautures() out_names = ('q', 'k', 'v') @@ -285,7 +297,8 @@ def __init__(self, device=device, is_tp=is_tp, out_names=out_names, - dp_gather=dp_gather) + dp_gather=dp_gather, + layer_type='attn') def _update_all_out_features(self, all_out_features: List[int], replicate: Optional[List[bool]]): """Update all out features.""" @@ -293,7 +306,7 @@ def _update_all_out_features(self, all_out_features: List[int], replicate: Optio def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any): """Weight loader.""" - _, rank = _get_tp_world_rank(self.is_tp) + _, rank = 
self.get_tp_world_rank() shard_idx = self.out_names_map[shard_id] num_head = self.num_q_heads if shard_id == 'q' \ diff --git a/lmdeploy/pytorch/nn/linear/default.py b/lmdeploy/pytorch/nn/linear/default.py index a2556710a6..a3f8a31a2c 100644 --- a/lmdeploy/pytorch/nn/linear/default.py +++ b/lmdeploy/pytorch/nn/linear/default.py @@ -4,12 +4,12 @@ import torch from lmdeploy.pytorch.backends import OpType, get_backend -from lmdeploy.pytorch.distributed import get_tp_world_rank +from lmdeploy.pytorch.config import TPMode from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader from ..utils import chunk_aligned, get_distribute_size from .base import LinearBase -from .utils import QKVMixin, _get_tp_world_rank, check_qkv_split_layout +from .utils import QKVMixin, check_qkv_split_layout class BaseLinear(LinearBase): @@ -27,7 +27,7 @@ def __init__( all_reduce: bool = True, tp_align_size: int = 1, dp_gather: bool = False, - dp_scatter: bool = False, + layer_type: str = 'attn', ): super().__init__(dtype=dtype, device=device, @@ -36,7 +36,7 @@ def __init__( all_reduce=all_reduce, tp_align_size=tp_align_size, dp_gather=dp_gather, - dp_scatter=dp_scatter) + layer_type=layer_type) if self.is_tp: in_features, out_features = self._get_io_features(in_features, out_features, colwise) impl_builder = get_backend().get_layer_impl_builder(OpType.Linear) @@ -64,7 +64,7 @@ def register_all_parameters(self, weight: torch.Tensor, bias: Optional[torch.Ten def _get_io_features(self, in_features: int, out_features: int, colwise: bool): """Get io features.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() if colwise: out_features = get_distribute_size(out_features, world_size, rank, align=self.tp_align_size) else: @@ -95,7 +95,7 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor): if not self.is_tp: return default_weight_loader(param, loaded_weight) - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() if self.colwise: return self._weight_loader_tp_colwise(param, loaded_weight, rank, world_size) else: @@ -117,11 +117,17 @@ def update_weights(self): def _forward_default(self, x, all_reduce, tp_sizes): """Default forward implement.""" - if self.dp_scatter: - _, rank = get_tp_world_rank() - return self.impl.forward(x, self.weight, self.bias, all_reduce, rank, tp_sizes) + if self.tp_mode == TPMode.DP_TP: + rank = self.tp_rank + return self.impl.forward(x, + self.weight, + self.bias, + all_reduce, + group=self.tp_group, + rank=rank, + scatter_size=tp_sizes) else: - return self.impl.forward(x, self.weight, self.bias, all_reduce) + return self.impl.forward(x, self.weight, self.bias, all_reduce, group=self.tp_group) class MergedBaseLinear(BaseLinear): @@ -135,9 +141,10 @@ def __init__(self, device: Optional[torch.device] = None, is_tp: bool = True, out_names: Optional[List[int]] = None, - dp_gather: bool = False): + dp_gather: bool = False, + layer_type: str = 'attn'): + self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type=layer_type) self.split_section = all_out_features - self.is_tp = is_tp all_out_features = self._update_all_out_features(all_out_features) self.all_out_features = all_out_features if out_names is None: @@ -145,7 +152,15 @@ def __init__(self, assert len(out_names) == len(self.all_out_features) self.out_names_map = dict((name, idx) for idx, name in enumerate(out_names)) out_features = sum(all_out_features) - super().__init__(in_features, 
out_features, bias, dtype, device, colwise=True, is_tp=is_tp, dp_gather=dp_gather) + super().__init__(in_features, + out_features, + bias, + dtype, + device, + colwise=True, + is_tp=is_tp, + dp_gather=dp_gather, + layer_type=layer_type) self.setup_loaders() def setup_loaders(self): @@ -162,7 +177,7 @@ def _get_io_features(self, in_features: int, out_features: int, colwise: bool): def _update_all_out_features(self, all_out_features: List[int]): """Update all out features.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() new_all_out_features = [] for out_feat in all_out_features: new_out_feat = get_distribute_size(out_feat, world_size, rank) @@ -171,7 +186,7 @@ def _update_all_out_features(self, all_out_features: List[int]): def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any): """Weight loader.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() shard_idx = self.out_names_map[shard_id] param_w = param.data.split(self.all_out_features, 0)[shard_idx] loaded_weight = loaded_weight.chunk(world_size, 0)[rank] @@ -199,13 +214,16 @@ def __init__(self, device: Optional[torch.device] = None, is_tp: bool = True, num_replicate_kv_heads: int = 1): + self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type='attn') QKVMixin.__init__(self, num_q_heads=num_q_heads, num_kv_heads=num_kv_heads, head_size=head_size, head_size_v=head_size_v, num_replicate_kv_heads=num_replicate_kv_heads, - is_tp=is_tp) + is_tp=is_tp, + tp=self.tp, + tp_rank=self.tp_rank) all_out_features = self.get_qkv_out_feautures() out_names = ('q', 'k', 'v') @@ -215,7 +233,8 @@ def __init__(self, dtype=dtype, device=device, is_tp=is_tp, - out_names=out_names) + out_names=out_names, + layer_type='attn') def _update_all_out_features(self, all_out_features: List[int]): """Update all out features.""" @@ -223,7 +242,7 @@ def _update_all_out_features(self, all_out_features: List[int]): def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any): """Weight loader.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() chunk_size, chunk_idx = world_size, rank shard_idx = self.out_names_map[shard_id] param_w = param.data.split(self.all_out_features, 0)[shard_idx] diff --git a/lmdeploy/pytorch/nn/linear/utils.py b/lmdeploy/pytorch/nn/linear/utils.py index 433a26c720..3e8c7db6f6 100644 --- a/lmdeploy/pytorch/nn/linear/utils.py +++ b/lmdeploy/pytorch/nn/linear/utils.py @@ -17,18 +17,10 @@ def check_qkv_split_layout(layout: str): f'but get: {layout}') -def _get_tp_world_rank(is_tp: bool): - """Get tp world size.""" - if is_tp: - return get_tp_world_rank() - else: - return 1, 0 - - -def update_tp_args(is_tp: bool, all_reduce: bool, colwise: bool): +def update_tp_args(is_tp: bool, all_reduce: bool, colwise: bool, layer_type: str = 'attn'): """Update tp args according to the environment.""" if is_tp: - world, _ = get_tp_world_rank() + world, _ = get_tp_world_rank(layer_type) is_tp = world > 1 if not is_tp or colwise: @@ -46,10 +38,12 @@ def __init__(self, head_size: int, head_size_v: int, num_replicate_kv_heads: int = 1, - is_tp: bool = False): + is_tp: bool = False, + tp: int = 1, + tp_rank: int = 0): qkv_split_section = self._get_qkv_out_features(num_q_heads, num_kv_heads, head_size, head_size_v, num_replicate_kv_heads) - num_q_heads, num_kv_heads = self._update_num_heads(is_tp, num_q_heads, num_kv_heads) + num_q_heads, 
num_kv_heads = self._update_num_heads(is_tp, tp, tp_rank, num_q_heads, num_kv_heads) self.num_q_heads = num_q_heads self.num_kv_heads = num_kv_heads self.head_size = head_size @@ -72,11 +66,11 @@ def _get_qkv_out_features(self, all_out_features = (num_q_heads * head_size, num_kv_heads_real * head_size, num_kv_heads_real * head_size_v) return all_out_features - def _update_num_heads(self, is_tp: bool, num_q_heads: int, num_kv_heads: int): + def _update_num_heads(self, is_tp: bool, tp: int, tp_rank: int, num_q_heads: int, num_kv_heads: int): """Update num heads.""" if not is_tp: return num_q_heads, num_kv_heads - world_size, rank = get_tp_world_rank() + world_size, rank = tp, tp_rank num_q_heads = get_distribute_size(num_q_heads, world_size, rank) num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank) diff --git a/lmdeploy/pytorch/nn/linear/w8a8.py b/lmdeploy/pytorch/nn/linear/w8a8.py index d5d0d476f7..c9105e5599 100644 --- a/lmdeploy/pytorch/nn/linear/w8a8.py +++ b/lmdeploy/pytorch/nn/linear/w8a8.py @@ -8,7 +8,7 @@ from ..utils import get_distribute_size from .base import LinearBase -from .utils import QKVMixin, _get_tp_world_rank, check_qkv_split_layout +from .utils import QKVMixin, check_qkv_split_layout class W8A8Linear(LinearBase): @@ -23,8 +23,14 @@ def __init__(self, colwise: bool = True, is_tp: bool = False, all_reduce: bool = True, - quant_dtype: Optional[torch.dtype] = torch.int8): - super().__init__(dtype=torch.float16, device=device, colwise=colwise, is_tp=is_tp, all_reduce=all_reduce) + quant_dtype: Optional[torch.dtype] = torch.int8, + layer_type: str = 'attn'): + super().__init__(dtype=torch.float16, + device=device, + colwise=colwise, + is_tp=is_tp, + all_reduce=all_reduce, + layer_type=layer_type) if self.is_tp: in_features, out_features = self._get_io_features(in_features, out_features, colwise) impl_builder = get_backend().get_layer_impl_builder(OpType.LinearW8A8) @@ -60,7 +66,7 @@ def register_all_parameters(self, weight: torch.Tensor, scale: torch.Tensor, bia def _get_io_features(self, in_features: int, out_features: int, colwise: bool): """Get io features.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() if colwise: out_features = get_distribute_size(out_features, world_size, rank) else: @@ -94,7 +100,7 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor): if not self.is_tp: return default_weight_loader(param, loaded_weight) - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() if self.colwise: return self._weight_loader_tp_colwise(param, loaded_weight, rank, world_size) else: @@ -117,7 +123,7 @@ def update_weights(self): def _forward_default(self, x, all_reduce, tp_sizes): """Default forward implement.""" - return self.impl.forward(x, self.weight, self.scale, self.bias, all_reduce) + return self.impl.forward(x, self.weight, self.scale, self.bias, all_reduce, group=self.tp_group) class MergedW8A8Linear(W8A8Linear): @@ -131,9 +137,10 @@ def __init__(self, device: Optional[torch.device] = None, is_tp: bool = True, out_names: Optional[List[int]] = None, - quant_dtype: torch.dtype = torch.int8): + quant_dtype: torch.dtype = torch.int8, + layer_type: str = 'attn'): + self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type=layer_type) self.split_section = all_out_features - self.is_tp = is_tp all_out_features = self._update_all_out_features(all_out_features) self.all_out_features = all_out_features if out_names is None: @@ 
-148,7 +155,8 @@ def __init__(self, device, colwise=True, is_tp=is_tp, - quant_dtype=quant_dtype) + quant_dtype=quant_dtype, + layer_type=layer_type) self.setup_loaders() def setup_loaders(self): @@ -167,7 +175,7 @@ def _get_io_features(self, in_features: int, out_features: int, colwise: bool): def _update_all_out_features(self, all_out_features: List[int]): """Update all out features.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() new_all_out_features = [] for out_feat in all_out_features: new_out_feat = get_distribute_size(out_feat, world_size, rank) @@ -176,7 +184,7 @@ def _update_all_out_features(self, all_out_features: List[int]): def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any): """Weight loader.""" - world_size, rank = _get_tp_world_rank(self.is_tp) + world_size, rank = self.get_tp_world_rank() shard_idx = self.out_names_map[shard_id] param_w = param.data.split(self.all_out_features, 0)[shard_idx] loaded_weight = loaded_weight.chunk(world_size, 0)[rank] @@ -205,13 +213,16 @@ def __init__(self, is_tp: bool = True, num_replicate_kv_heads: int = 1, quant_dtype: torch.dtype = torch.int8): + self.init_tp_args(is_tp, all_reduce=False, colwise=True, layer_type='attn') QKVMixin.__init__(self, num_q_heads=num_q_heads, num_kv_heads=num_kv_heads, head_size=head_size, head_size_v=head_size_v, num_replicate_kv_heads=num_replicate_kv_heads, - is_tp=is_tp) + is_tp=is_tp, + tp=self.tp, + tp_rank=self.tp_rank) all_out_features = self.get_qkv_out_feautures() out_names = ('q', 'k', 'v') @@ -222,7 +233,8 @@ def __init__(self, device=device, is_tp=is_tp, out_names=out_names, - quant_dtype=quant_dtype) + quant_dtype=quant_dtype, + layer_type='attn') def _update_all_out_features(self, all_out_features: List[int]): """Update all out features.""" @@ -230,7 +242,7 @@ def _update_all_out_features(self, all_out_features: List[int]): def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, shard_id: Any): """Weight loader.""" - _, rank = _get_tp_world_rank(self.is_tp) + _, rank = self.get_tp_world_rank() shard_idx = self.out_names_map[shard_id] param_w = param.data.split(self.all_out_features, 0)[shard_idx] num_head = self.num_q_heads if shard_id == 'q' \ diff --git a/lmdeploy/pytorch/nn/moe.py b/lmdeploy/pytorch/nn/moe.py index 0f88cdf851..5012969996 100644 --- a/lmdeploy/pytorch/nn/moe.py +++ b/lmdeploy/pytorch/nn/moe.py @@ -7,6 +7,7 @@ from torch import nn import lmdeploy.pytorch.distributed as dist +from lmdeploy.pytorch.config import TPMode from lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank, get_tp_world_rank from lmdeploy.pytorch.model_inputs import get_step_ctx_manager @@ -47,12 +48,119 @@ def create_mlp_weights(hidden_dim: int, ffn_dim: int, num_experts: int, dtype: t def _update_args(hidden_dim: int, ffn_dim: int): """Update args.""" - world_size, _ = get_tp_world_rank() + world_size, _ = get_tp_world_rank('moe') assert ffn_dim % world_size == 0 ffn_dim = ffn_dim // world_size return hidden_dim, ffn_dim +def _split_size(size: int, world_size: int, align: int): + size = size // align + assert size >= world_size + base = size // world_size + remain = size % world_size + split_size = [base + 1] * remain + [base] * (world_size - remain) + split_size = [s * align for s in split_size] + return split_size + + +class MoEForwardDPTP: + + def __init__(self, gemm_func: Callable, max_tokens_per_round: int = 8192): + """MoE forward dp tp.""" + self.gemm_func = 
gemm_func + self.dist_ctx = get_dist_manager().current_context() + self.dist_config = self.dist_ctx.dist_config + self.tp = self.dist_config.moe_tp + self.attn_tp = self.dist_config.attn_tp + + tp_group = self.dist_ctx.moe_tp_group + self.rank = tp_group.rank + self.gather_rank = self.rank // self.attn_tp + self.gather_group = tp_group.gpu_gather_group + self.tp_group = tp_group.gpu_group + self.max_tokens_per_round = max_tokens_per_round * self.attn_tp // self.tp // 2 + + def all_gather(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + tp_sizes: List[int]): + """All gather.""" + hidden_states, _ = dist.gather_by_tp_sizes(hidden_states, tp_sizes, group=self.gather_group, async_op=True) + topk_weights, _ = dist.gather_by_tp_sizes(topk_weights, tp_sizes, group=self.gather_group, async_op=True) + topk_ids, handle = dist.gather_by_tp_sizes(topk_ids, tp_sizes, group=self.gather_group, async_op=True) + return hidden_states, topk_weights, topk_ids, handle + + def reduce_scatter(self, hidden_states: torch.Tensor, out_states: torch.Tensor, tp_sizes: List[int]): + """Reduce scatter.""" + hidden_states_list = list(hidden_states.split(tp_sizes, -2)) + hidden_states_list[self.gather_rank] = out_states + hidden_states_list = [item for item in hidden_states_list for _ in range(self.attn_tp)] + handle = dist.reduce_scatter(out_states, hidden_states_list, group=self.tp_group, async_op=True) + return out_states, handle + + def _gemm_and_reduce_scatter(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + output_states: torch.Tensor, tp_sizes: List[int], handle: dist.Work): + """Gemm and reduce scatter.""" + handle.wait() + cur_out = self.gemm_func(hidden_states, topk_weights, topk_ids) + cur_out_states = cur_out.split(tp_sizes, dim=0)[self.gather_rank] + output_states.copy_(cur_out_states) + return self.reduce_scatter(cur_out, output_states, tp_sizes) + + def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor): + """forward.""" + + def __slice_tensor(tensor: torch.Tensor, slice_size: int): + """Slice tensor.""" + cur_tensor = tensor[:slice_size] + tensor = tensor[slice_size:] + return cur_tensor, tensor + + def __slice_and_gather(): + """Slice and gather.""" + nonlocal hidden_states, topk_weights, topk_ids, tp_sizes, output_states + cur_tp_sizes = tp_sizes.minimum(max_tokens_per_round) + tp_sizes -= cur_tp_sizes + cur_tp_sizes = cur_tp_sizes.tolist() + + slice_size = cur_tp_sizes[self.gather_rank] + cur_hidden_states, hidden_states = __slice_tensor(hidden_states, slice_size) + cur_topk_weights, topk_weights = __slice_tensor(topk_weights, slice_size) + cur_topk_ids, topk_ids = __slice_tensor(topk_ids, slice_size) + cur_output, output_states = __slice_tensor(output_states, slice_size) + + # all gather + cur_hidden_states, cur_topk_weights, cur_topk_ids, handle = self.all_gather( + cur_hidden_states, cur_topk_weights, cur_topk_ids, cur_tp_sizes) + return dict(hidden_states=cur_hidden_states, + topk_weights=cur_topk_weights, + topk_ids=cur_topk_ids, + output_states=cur_output, + handle=handle, + tp_sizes=cur_tp_sizes) + + step_ctx = get_step_ctx_manager().current_context() + tp_sizes = step_ctx.dp_meta.moe_tp_sizes + tp_sizes = torch.tensor(tp_sizes) + max_tokens_per_round = tp_sizes.new_tensor(self.max_tokens_per_round) + + output_states = torch.empty_like(hidden_states) + return_states = output_states + + # pre + cur_inputs = __slice_and_gather() + + # main loop + while tp_sizes.sum() > 0: + 
next_inputs = __slice_and_gather() + self._gemm_and_reduce_scatter(**cur_inputs) + cur_inputs = next_inputs + + # post + _, handle = self._gemm_and_reduce_scatter(**cur_inputs) + handle.wait() + return return_states + + class LinearWeights(nn.Module): """Fused moe linear weights.""" @@ -107,7 +215,7 @@ def update_weight(self, weight: torch.Tensor): def weight_loader_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, shard_id: str): """Weight loader.""" - world_size, rank = get_tp_world_rank() + world_size, rank = get_tp_world_rank('moe') if shard_id == 'gate': param_data = param.data[expert_id, :self.half_out] weight = loaded_weight.chunk(world_size, dim=0)[rank] @@ -147,67 +255,44 @@ def weight_loader_ep(self, param: torch.nn.Parameter, loaded_weight: torch.Tenso param_data.copy_(loaded_weight) -def _gather_input(x: torch.Tensor, tp_sizes: List[int]): - """Gather input.""" - shape0 = x.shape[:-2] - shape1 = x.shape[-1:] - shapes = [shape0 + (size, ) + shape1 for size in tp_sizes] - new_x = [x.new_empty(shape) for shape in shapes] - dist.all_gather(new_x, x) - x = torch.cat(new_x, dim=-2) - return x - - -def _reduce_scatter_input(out: torch.Tensor, tp_sizes: List[int]): - """Reduce scatter.""" - _, rank = get_tp_world_rank() - out = out.transpose(0, -2) - if not out.is_contiguous(): - out = out.contiguous() - outs = out.split(tp_sizes, 0) - outs = list(outs) - out = outs[rank] - dist.reduce_scatter(out, outs) - out = out.transpose(0, -2) - return out - - -def _moe_gather_inputs(hidden_states, topk_weights, topk_ids, enable_ep): - dist_ctx = get_dist_manager().current_context() - dp = dist_ctx.dp - if dp <= 1: +def _moe_gather_inputs(hidden_states, topk_weights, topk_ids, group: Optional[dist.ProcessGroup] = None): + dist_config = get_dist_manager().current_config() + tp = dist_config.moe_tp + if tp == 1: return hidden_states, topk_weights, topk_ids - step_ctx = get_step_ctx_manager().current_context() - dp_meta = step_ctx.dp_meta - if not enable_ep: - if dist_ctx.tp == 1: - return hidden_states, topk_weights, topk_ids - tp_sizes = dp_meta.tp_sizes - hidden_states = _gather_input(hidden_states, tp_sizes) - topk_weights = _gather_input(topk_weights, tp_sizes) - topk_ids = _gather_input(topk_ids, tp_sizes) + tp_mode = dist_config.moe_tp_mode + if tp_mode == TPMode.DEFAULT: + return hidden_states, topk_weights, topk_ids + elif tp_mode == TPMode.DP_TP: + step_ctx = get_step_ctx_manager().current_context() + dp_meta = step_ctx.dp_meta + tp_sizes = dp_meta.moe_tp_sizes + hidden_states = dist.gather_by_tp_sizes(hidden_states, tp_sizes, group=group) + topk_weights = dist.gather_by_tp_sizes(topk_weights, tp_sizes, group=group) + topk_ids = dist.gather_by_tp_sizes(topk_ids, tp_sizes, group=group) else: raise RuntimeError('Not supported.') + return hidden_states, topk_weights, topk_ids -def _moe_reduce(ret, enable_ep): - dist_ctx = get_dist_manager().current_context() - dp = dist_ctx.dp - if dp > 1: +def _moe_reduce(ret, rank: int, tp_mode: TPMode, group: Optional[dist.ProcessGroup] = None): + dist_config = get_dist_manager().current_config() + if dist_config.moe_tp == 1: + return ret + + if tp_mode == TPMode.DEFAULT: + dist.all_reduce(ret, group=group) + return ret + elif tp_mode == TPMode.DP_TP: step_ctx = get_step_ctx_manager().current_context() dp_meta = step_ctx.dp_meta - if not enable_ep: - if dist_ctx.tp == 1: - return ret - tp_sizes = dp_meta.tp_sizes - ret = _reduce_scatter_input(ret, tp_sizes) - else: - raise RuntimeError('Not supported.') + tp_size = 
dp_meta.moe_tp_sizes + ret = dist.reduce_scatter_by_tp_sizes(ret, rank, tp_size, group=group) + return ret else: - dist.all_reduce(ret) - return ret + raise RuntimeError('Not supported.') class FusedMoE(nn.Module): @@ -230,13 +315,14 @@ def __init__(self, device = torch.device('cpu') if dtype is None: dtype = torch.float16 + self.init_tp_args(all_reduce, enable_ep) impl_builder = get_backend().get_layer_impl_builder(OpType.FusedMoE) self.impl = impl_builder.build(top_k, num_experts, renormalize) enable_ep = enable_ep and self.impl.support_ep() if enable_ep: - world_size, rank = get_tp_world_rank() + world_size, rank = get_tp_world_rank('moe') expert_list = self.impl.ep_expert_list(world_size, rank) num_experts = len(expert_list) else: @@ -269,22 +355,68 @@ def __init__(self, self.num_experts = num_experts self.dtype = dtype self.device = device - world_size, _ = get_tp_world_rank() - if world_size == 1: - all_reduce = False - self.all_reduce = all_reduce self.enable_ep = enable_ep self.act_func = act_func + def init_tp_args(self, all_reduce: bool, enable_ep: bool): + """Init tp args.""" + tp, tp_rank = get_tp_world_rank('moe') + dist_ctx = get_dist_manager().current_context() + dist_cfg = dist_ctx.dist_config + _, tp_mode = dist_cfg.get_tp_by_layer('moe') + tp = 1 if enable_ep else tp + tp_rank = 0 if enable_ep else tp_rank + all_reduce = all_reduce if tp > 1 else False + all_reduce = False if enable_ep else all_reduce + + self.tp = tp + self.tp_rank = tp_rank + self.tp_mode = tp_mode + self.all_reduce = all_reduce + self.tp_group = dist_ctx.moe_tp_group.gpu_group + self.gather_group = dist_ctx.moe_tp_group.gpu_gather_group + + if self.tp > 1 and self.tp_mode == TPMode.DP_TP: + + def __gemm_func(hidden_states, topk_weights, topk_ids): + return self.gemm( + dict( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_idx=topk_ids, + moe_type=MoeType.Default, + ))['hidden_states'] + + self.forward_dptp = MoEForwardDPTP(__gemm_func) + def update_weights(self): """Update weights.""" gate_up_weights, down_weights = self.impl.update_weights(self.gate_up.weight, self.down.weight) self.gate_up.update_weight(gate_up_weights) self.down.update_weight(down_weights) - def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.LongTensor): - hidden_states, topk_weights, topk_ids = _moe_gather_inputs(hidden_states, topk_weights, topk_ids, - self.enable_ep) + def gemm(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """Gemm.""" + hidden_states = inputs['hidden_states'] + topk_weights = inputs['topk_weights'] + topk_ids = inputs['topk_idx'] + + ret = self.impl.forward(hidden_states, + topk_weights, + topk_ids, + self.gate_up.weight, + self.down.weight, + self.gate_up.bias, + self.down.bias, + self.expert_list, + act_func=self.act_func) + return dict(hidden_states=ret) + + def forward_default(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.LongTensor): + hidden_states, topk_weights, topk_ids = _moe_gather_inputs(hidden_states, + topk_weights, + topk_ids, + group=self.gather_group) ret = self.impl.forward(hidden_states, topk_weights, @@ -296,9 +428,15 @@ def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ self.expert_list, act_func=self.act_func) if self.all_reduce: - ret = _moe_reduce(ret, self.enable_ep) + ret = _moe_reduce(ret, rank=self.tp_rank, tp_mode=self.tp_mode, group=self.tp_group) return ret + def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: 
torch.LongTensor): + """forward.""" + if self.tp > 1 and self.tp_mode == TPMode.DP_TP: + return self.forward_dptp.forward(hidden_states, topk_weights, topk_ids) + return self.forward_default(hidden_states, topk_weights, topk_ids) + class LinearWeightsW8A8(LinearWeights): """Fused moe linear w8a8 weights.""" @@ -342,7 +480,7 @@ def update_weight(self, weight: torch.Tensor, scale: torch.Tensor): def weight_loader_scale_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, shard_id: str): """Weight loader scale tp.""" - world_size, rank = get_tp_world_rank() + world_size, rank = get_tp_world_rank('moe') if shard_id == 'gate': param_data = param.data[expert_id, :self.half_out] weight = loaded_weight.chunk(world_size, dim=0)[rank] @@ -383,7 +521,7 @@ def __init__(self, enable_ep = enable_ep and self.impl.support_ep() if enable_ep: - world_size, rank = get_tp_world_rank() + world_size, rank = get_tp_world_rank('moe') expert_list = self.impl.ep_expert_list(world_size, rank) num_experts = len(expert_list) else: @@ -413,7 +551,7 @@ def __init__(self, self.num_experts = num_experts self.dtype = dtype self.device = device - world_size, _ = get_tp_world_rank() + world_size, _ = get_tp_world_rank('moe') if world_size == 1: all_reduce = False self.all_reduce = all_reduce @@ -465,13 +603,14 @@ def __init__(self, device=device) weight_scale_inv = torch.nn.Parameter(weight_scale_inv, requires_grad=False) self.register_parameter('weight_scale_inv', weight_scale_inv) - self.weight._base_weight_loader = self.weight.weight_loader - self.weight.weight_loader = self.weight_loader_with_quant if self.ep: + self.weight._base_weight_loader = self.weight.weight_loader self.weight_scale_inv.weight_loader = self.weight_loader_scale_ep else: + self.weight._base_weight_loader = self.weight_loader_tp_blocked_fp8 self.weight_scale_inv.weight_loader = self.weight_loader_scale_tp + self.weight.weight_loader = self.weight_loader_with_quant def update_weight(self, weight: torch.Tensor, weight_scale_inv: torch.Tensor): """Update weight.""" @@ -490,22 +629,58 @@ def weight_loader_scale_ep(self, param: torch.nn.Parameter, loaded_weight: torch for expert_id in expert_ids: self.weight_loader_scale_tp(param, loaded_weight, expert_id, shard_id) + def _chunk_weight_tp(self, weight: torch.Tensor, dim: int, world_size: int, rank: int, align: int): + """Chunk with align.""" + split_size = _split_size(weight.size(dim), world_size, align) + return weight.split(split_size, dim=dim)[rank] + + def weight_loader_tp_blocked_fp8(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, + shard_id: str): + """Weight loader.""" + world_size, rank = get_tp_world_rank('moe') + if shard_id == 'gate': + param_data = param.data[expert_id, :self.half_out] + weight = self._chunk_weight_tp(loaded_weight, + dim=0, + world_size=world_size, + rank=rank, + align=self.block_size) + elif shard_id == 'up': + param_data = param.data[expert_id, self.half_out:] + weight = self._chunk_weight_tp(loaded_weight, + dim=0, + world_size=world_size, + rank=rank, + align=self.block_size) + elif shard_id == 'down': + param_data = param.data[expert_id] + # weight is not contiguous, chunk and copy in cpu is slow + weight = loaded_weight.to(param_data.device) + if weight.dim() > 1: + weight = self._chunk_weight_tp(weight, dim=1, world_size=world_size, rank=rank, align=self.block_size) + elif weight.dim() == 1 and rank != 0: + # bias with rank>0 should be 0 + weight = torch.zeros_like(weight) + else: + raise RuntimeError(f'Unknown 
shard_id: {shard_id}') + param_data.copy_(weight) + def weight_loader_scale_tp(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int, shard_id: str): """Weight loader scale tp.""" - world_size, rank = get_tp_world_rank() + world_size, rank = get_tp_world_rank('moe') block_size = self.block_size half_out = self.half_out // block_size if shard_id == 'gate': param_data = param.data[expert_id, :half_out] - weight = loaded_weight.chunk(world_size, dim=0)[rank] + weight = self._chunk_weight_tp(loaded_weight, dim=0, world_size=world_size, rank=rank, align=1) elif shard_id == 'up': param_data = param.data[expert_id, half_out:] - weight = loaded_weight.chunk(world_size, dim=0)[rank] + weight = self._chunk_weight_tp(loaded_weight, dim=0, world_size=world_size, rank=rank, align=1) elif shard_id == 'down': param_data = param.data[expert_id] loaded_weight = loaded_weight.to(param_data.device) - weight = loaded_weight.chunk(world_size, dim=1)[rank] + weight = self._chunk_weight_tp(loaded_weight, dim=1, world_size=world_size, rank=rank, align=1) else: raise RuntimeError(f'Unknown shard_id: {shard_id}') param_data.copy_(weight) @@ -544,6 +719,7 @@ def __init__(self, device = torch.device('cpu') dtype = torch.float16 if dtype is None else dtype self.block_size = 128 + self.init_tp_args(all_reduce, enable_ep) dist_ctx = get_dist_manager().current_context() self.ep_size, rank = get_ep_world_rank() impl_builder = get_backend().get_layer_impl_builder(OpType.FusedMoEBlockedF8) @@ -562,7 +738,7 @@ def __init__(self, expert_list = self.impl.ep_expert_list(self.ep_size, rank) num_experts = len(expert_list) else: - hidden_dim, ffn_dim = _update_args(hidden_dim, ffn_dim) + hidden_dim, ffn_dim = self._update_args(hidden_dim, ffn_dim, align=self.block_size) expert_list = None self.expert_list = expert_list @@ -594,12 +770,46 @@ def __init__(self, self.num_experts = num_experts self.dtype = dtype self.device = device - world_size, _ = get_tp_world_rank() - if world_size == 1: - all_reduce = False - self.all_reduce = all_reduce self.act_func = act_func + @staticmethod + def _update_args(hidden_dim: int, ffn_dim: int, align: int): + """Update args.""" + world_size, rank = get_tp_world_rank('moe') + split_size = _split_size(ffn_dim, world_size, align) + ffn_dim = split_size[rank] + return hidden_dim, ffn_dim + + def init_tp_args(self, all_reduce: bool, enable_ep: bool): + """Init tp args.""" + tp, tp_rank = get_tp_world_rank('moe') + dist_ctx = get_dist_manager().current_context() + dist_cfg = dist_ctx.dist_config + _, tp_mode = dist_cfg.get_tp_by_layer('moe') + tp = 1 if enable_ep else tp + tp_rank = 0 if enable_ep else tp_rank + all_reduce = all_reduce if tp > 1 else False + all_reduce = False if enable_ep else all_reduce + + self.tp = tp + self.tp_rank = tp_rank + self.tp_mode = tp_mode + self.all_reduce = all_reduce + self.tp_group = dist_ctx.moe_tp_group.gpu_group + self.gather_group = dist_ctx.moe_tp_group.gpu_gather_group + if self.tp > 1 and self.tp_mode == TPMode.DP_TP: + + def __gemm_func(hidden_states, topk_weights, topk_ids): + return self.gemm( + dict( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_idx=topk_ids, + moe_type=MoeType.Default, + ))['hidden_states'] + + self.forward_dptp = MoEForwardDPTP(__gemm_func) + def update_weights(self): """Update weights.""" (gate_up_weights, down_weights, gate_up_scale, @@ -608,7 +818,7 @@ def update_weights(self): self.gate_up.update_weight(gate_up_weights, gate_up_scale) self.down.update_weight(down_weights, down_scale) - def 
forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_idx: torch.LongTensor): + def forward_default(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_idx: torch.LongTensor): state = { 'hidden_states': hidden_states, 'topk_idx': topk_idx, @@ -620,6 +830,13 @@ def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ out_state = self.combine(gemm_state) return out_state['hidden_states'] + def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_idx: torch.LongTensor): + + if self.tp > 1 and self.tp_mode == TPMode.DP_TP: + return self.forward_dptp.forward(hidden_states, topk_weights, topk_idx) + else: + return self.forward_default(hidden_states, topk_weights, topk_idx) + def before_dispatch(self, state: Dict): moe_type = state['moe_type'] if moe_type == MoeType.DSAsyncPrefill: @@ -682,8 +899,10 @@ def dispatch(self, state: Dict): else: recv_state['hook'] = hook else: # MoeType.Default - hidden_states, topk_weights, topk_idx = _moe_gather_inputs(state['hidden_states'], state['topk_weights'], - state['topk_idx'], False) + hidden_states, topk_weights, topk_idx = _moe_gather_inputs(state['hidden_states'], + state['topk_weights'], + state['topk_idx'], + group=self.gather_group) recv_state = { 'hidden_states': hidden_states, 'topk_idx': topk_idx, @@ -769,7 +988,10 @@ def combine(self, state: Dict): out_state['hook'] = hook else: # MoeType.Default if self.all_reduce: - state['hidden_states'] = _moe_reduce(state['hidden_states'], False) + state['hidden_states'] = _moe_reduce(state['hidden_states'], + rank=self.tp_rank, + tp_mode=self.tp_mode, + group=self.tp_group) out_state = {'hidden_states': state['hidden_states'], 'moe_type': state['moe_type']} return out_state diff --git a/lmdeploy/pytorch/nn/norm.py b/lmdeploy/pytorch/nn/norm.py index 908624b3c8..f477467f0d 100644 --- a/lmdeploy/pytorch/nn/norm.py +++ b/lmdeploy/pytorch/nn/norm.py @@ -45,7 +45,7 @@ def __init__(self, builder = backend.get_layer_impl_builder(OpType.RMSNorm) if tp: - world_size, rank = get_tp_world_rank() + world_size, rank = get_tp_world_rank('attn') hidden_size = get_distribute_size(hidden_size, world_size, rank, align=align) self.register_parameter('weight', self.create_weight(hidden_size, dtype, device)) @@ -60,7 +60,7 @@ def __init__(self, def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): """Weight loader.""" - world_size, rank = get_tp_world_rank() + world_size, rank = get_tp_world_rank('attn') loaded_weight = chunk_aligned(loaded_weight, world_size, 0, self.align)[rank] param.copy_(loaded_weight) diff --git a/lmdeploy/pytorch/strategies/ar/model_agent.py b/lmdeploy/pytorch/strategies/ar/model_agent.py index 4598e02011..6d397916a7 100644 --- a/lmdeploy/pytorch/strategies/ar/model_agent.py +++ b/lmdeploy/pytorch/strategies/ar/model_agent.py @@ -4,9 +4,9 @@ from typing import Any, List, Optional import torch +import torch.distributed as dist from torch.profiler import record_function -import lmdeploy.pytorch.distributed as dist from lmdeploy.pytorch.distributed import DistContext from lmdeploy.pytorch.engine.logits_process import SamplingInputs from lmdeploy.pytorch.messages import SchedulerSequence @@ -113,7 +113,8 @@ def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_ @contextmanager def broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: ExtraInputs, dist_ctx: DistContext): """Broadcast next token ids and extra inputs.""" - tp_gpu_group = dist_ctx.tp_gpu_group - 
handle = dist.broadcast(next_token_ids, src=0, group=tp_gpu_group, async_op=True) + tp_gpu_group = dist_ctx.attn_tp_group.gpu_group + rank = dist.get_global_rank(tp_gpu_group, 0) + handle = dist.broadcast(next_token_ids, src=rank, group=tp_gpu_group, async_op=True) yield handle.wait() diff --git a/lmdeploy/pytorch/strategies/dllm/model_agent.py b/lmdeploy/pytorch/strategies/dllm/model_agent.py index 935b5f1825..a1829b3e65 100644 --- a/lmdeploy/pytorch/strategies/dllm/model_agent.py +++ b/lmdeploy/pytorch/strategies/dllm/model_agent.py @@ -5,9 +5,9 @@ import numpy as np import torch +import torch.distributed as dist from torch.profiler import record_function -import lmdeploy.pytorch.distributed as dist from lmdeploy.pytorch import consts from lmdeploy.pytorch.config import DLLMConfig from lmdeploy.pytorch.distributed import DistContext @@ -232,8 +232,9 @@ def make_dummy_next_token(self, inputs: 'ModelInputs', logits: torch.Tensor, ext @contextmanager def broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: DLLMExtraInputs, dist_ctx: DistContext): """Broadcast next token ids and extra inputs.""" - tp_gpu_group = dist_ctx.tp_gpu_group - dist.broadcast(next_token_ids, src=0, group=tp_gpu_group, async_op=True) - handle = extra_inputs.broadcast(src=0, group=tp_gpu_group, async_op=True) + tp_gpu_group = dist_ctx.attn_tp_group.gpu_group + rank = dist.get_global_rank(tp_gpu_group, 0) + dist.broadcast(next_token_ids, src=rank, group=tp_gpu_group, async_op=True) + handle = extra_inputs.broadcast(src=rank, group=tp_gpu_group, async_op=True) yield handle.wait() diff --git a/lmdeploy/pytorch/tools/utils.py b/lmdeploy/pytorch/tools/utils.py index 98d42772c0..899dfc42fa 100644 --- a/lmdeploy/pytorch/tools/utils.py +++ b/lmdeploy/pytorch/tools/utils.py @@ -192,3 +192,24 @@ def _print_meta(out: Response): print(colored('─' * (term_size), border_color, attrs=['dark'])) else: print(colored('━' * term_size, border_color)) + + +def visualize_chat_completions(outputs, enable_meta: bool = True): + """Visualize chat completions.""" + from openai.types.chat import ChatCompletion + + from lmdeploy.messages import Response + if isinstance(outputs, ChatCompletion): + outputs = [outputs] + + resps = [] + for out in outputs: + assert isinstance(out, ChatCompletion) + choice = out.choices[0] + resp = Response(text=choice.message.content, + input_token_len=out.usage.prompt_tokens, + generate_token_len=out.usage.completion_tokens, + finish_reason=choice.finish_reason) + resps.append(resp) + + return visualize_pipe_out(resps, enable_meta=enable_meta) diff --git a/lmdeploy/serve/openai/launch_server.py b/lmdeploy/serve/openai/launch_server.py index 94121b6a94..2d2fd56c3f 100644 --- a/lmdeploy/serve/openai/launch_server.py +++ b/lmdeploy/serve/openai/launch_server.py @@ -97,11 +97,10 @@ def launch_server(num_nodes: int, ep = backend_config.ep assert dp > 1, f'only support dp > 1, but give dp={dp}' assert tp > 1 or ep > 1, f'only support tp > 1 or ep > 1, but given tp={tp} ep={ep}' - assert tp <= dp, f'only support tp <= dp, but give tp={tp}, dp={dp}' - assert ep <= dp, f'only support ep <= dp, but give tp={ep}, dp={dp}' + num_devices = max(dp, tp, ep) dp_per_node = dp // num_nodes - tp_per_dp = 1 # each dp uses one rank + tp_per_dp = num_devices // dp http_or_https = 'https' if kwargs.get('ssl', False) else 'http' model_name = kwargs.get('model_name', None) if model_name is None: diff --git a/lmdeploy/serve/proxy/proxy.py b/lmdeploy/serve/proxy/proxy.py index 4f88340593..11f0caec47 100644 --- 
a/lmdeploy/serve/proxy/proxy.py +++ b/lmdeploy/serve/proxy/proxy.py @@ -597,6 +597,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque # Prefill prefill_request_dict = copy.deepcopy(request_dict) prefill_request_dict['max_tokens'] = 1 + prefill_request_dict['max_completion_tokens'] = 1 prefill_request_dict['stream'] = False prefill_request_dict['with_cache'] = True prefill_request_dict['preserve_cache'] = True
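
Illustration (not part of the patch): a minimal sketch of the block-aligned TP split that the new `_update_args` and `_chunk_weight_tp` paths rely on, assuming `_split_size` hands out whole `align`-sized blocks and gives the leading ranks one extra block when the block count is not divisible. `aligned_split_sizes` and the toy shapes below are hypothetical stand-ins for illustration, not lmdeploy APIs.

# Sketch only: assumes _split_size distributes whole `align`-sized blocks,
# with remainder blocks assigned to the lowest ranks.
from typing import List

import torch


def aligned_split_sizes(size: int, world_size: int, align: int) -> List[int]:
    """Split `size` into `world_size` chunks, each a multiple of `align`."""
    # simplifying assumption for the sketch: size is a multiple of align
    assert size % align == 0
    num_blocks = size // align
    base, remain = divmod(num_blocks, world_size)
    return [(base + (1 if r < remain else 0)) * align for r in range(world_size)]


# toy usage: shard the ffn dimension of one expert weight for a single moe-tp rank
ffn_dim, hidden_dim, block_size = 1792, 1024, 128
world_size, rank = 4, 1
weight = torch.randn(ffn_dim, hidden_dim)

split_size = aligned_split_sizes(ffn_dim, world_size, block_size)  # [512, 512, 384, 384]
shard = weight.split(split_size, dim=0)[rank]                      # shape (512, 1024)

In the patch itself, `_update_args` keeps `split_size[rank]` as the local ffn_dim for the shard, while `weight_loader_scale_tp` reuses the same split with `align=1` because the per-block scales are already in units of `block_size`.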