diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py
index d61dfbf61e..dc000fd014 100644
--- a/lmdeploy/pytorch/backends/cuda/graph_runner.py
+++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py
@@ -6,7 +6,7 @@
 from lmdeploy.pytorch.backends.selector import get_backend
 from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
-from lmdeploy.pytorch.model_inputs import StepContext
+from lmdeploy.pytorch.model_inputs import StepContext, get_step_ctx_manager
 from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta
 from lmdeploy.utils import get_logger
 
@@ -150,11 +150,12 @@ def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, pas
         is_decoding = context.is_decoding
         num_tokens = input_ids.numel()
         meta = self.get_meta()
+        enable_microbatch = get_step_ctx_manager().current_context().enable_microbatch
         if meta.padding_batch_size is None:
             new_num_tokens = next_power_of_2(num_tokens)
         else:
             new_num_tokens = next_power_of_2(meta.padding_batch_size)
-        return (new_num_tokens, is_decoding)
+        return (new_num_tokens, is_decoding, enable_microbatch)
 
     def __call__(self, **kwargs):
         """call."""
diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py
index 233425b37e..7408f9d554 100644
--- a/lmdeploy/pytorch/engine/model_agent.py
+++ b/lmdeploy/pytorch/engine/model_agent.py
@@ -4,6 +4,7 @@
 import functools
 from contextlib import asynccontextmanager, contextmanager
 from multiprocessing.reduction import ForkingPickler
+from os import getenv
 from typing import Any, Dict
 
 import torch
@@ -274,6 +275,15 @@ def __init__(self,
         self.cache_engine = None
         self.profiler: AgentProfiler = None
 
+        # microbatch
+        self.enable_microbatch = self.dist_ctx.dist_config.enable_microbatch
+        self.enable_microbatch_prefill_batchsize_threshold = \
+            int(getenv('ENABLE_MICROBATCH_PREFILL_BATCHSIZE_THRESHOLD', 2))
+        self.enable_microbatch_prefill_token_threshold = \
+            int(getenv('ENABLE_MICROBATCH_PREFILL_TOKEN_THRESHOLD', 2))
+        self.enable_microbatch_decode_batchsize_threshold = \
+            int(getenv('ENABLE_MICROBATCH_DECODE_BATCHSIZE_THRESHOLD', 2))
+
     @contextmanager
     def all_context(self):
         device_mgr = get_device_manager()
@@ -517,6 +527,17 @@ async def __prepare_dp():
             # gather dp forward metadata
             batch_size = inputs.seq_length.numel()
             dp_forward_meta = [int(is_decoding), int(is_dummy), batch_size, int(sync_long_context)]
+            # check enable_microbatch
+            if self.enable_microbatch:
+                tokens_num = inputs.input_ids.numel()
+                if is_decoding:
+                    enable_microbatch = batch_size >= \
+                        self.enable_microbatch_decode_batchsize_threshold
+                else:
+                    enable_microbatch = batch_size >= \
+                        self.enable_microbatch_prefill_batchsize_threshold and \
+                        tokens_num >= self.enable_microbatch_prefill_token_threshold
+                dp_forward_meta.append(int(enable_microbatch))
             gathered_meta = DistGatherScalar(dp_forward_meta, dp, device='cuda')
 
             yield
@@ -543,6 +564,10 @@ async def __prepare_dp():
                 sync_long_context = all_sync_flags.any()
                 logger.debug(f'sync_long_context={sync_long_context}')
 
+            # update if enable_microbatch
+            if self.enable_microbatch and gathered_meta[:, 4].all():
+                inputs.enable_microbatch = True
+
             # update dp meta
             inputs.build_dp_meta()
             inputs = self.patched_model.update_inputs(inputs)
diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py
index 0efce3a868..3b3190c16d 100644
--- a/lmdeploy/pytorch/model_inputs.py
+++ b/lmdeploy/pytorch/model_inputs.py
@@ -152,6 +152,7 @@ class ModelInputs:
     history_cross_length: torch.LongTensor = None
     model_metas: List[Dict[str, Any]] = None
    dp_meta: 'DPMeta' = None
+    enable_microbatch: bool = False
 
     def update(self, input_ids: torch.LongTensor):
         """Update input ids."""
@@ -420,6 +421,7 @@ def new(
             cross_seqlens=cross_seqlens,
             cross_kv_seqlens=cross_kv_seqlens,
             dp_meta=inputs.dp_meta,
+            enable_microbatch=inputs.enable_microbatch,
         )
 
         ret = get_backend().update_step_context(ret)
diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py
index 25142c1a3d..b140f17b28 100644
--- a/lmdeploy/pytorch/models/deepseek_v2.py
+++ b/lmdeploy/pytorch/models/deepseek_v2.py
@@ -1152,13 +1152,6 @@ def __init__(self,
                                             dtype=dtype,
                                             device=device)
         self._load_buffers = dict()
-        self.enable_microbatch = get_dist_manager().current_context().dist_config.enable_microbatch
-        self.enable_microbatch_prefill_batchsize_threshold = \
-            int(getenv('ENABLE_MICROBATCH_PREFILL_BATCHSIZE_THRESHOLD', 2))
-        self.enable_microbatch_prefill_token_threshold = \
-            int(getenv('ENABLE_MICROBATCH_PREFILL_TOKEN_THRESHOLD', 2))
-        self.enable_microbatch_decode_batchsize_threshold = \
-            int(getenv('ENABLE_MICROBATCH_DECODE_BATCHSIZE_THRESHOLD', 2))
 
     def forward(
         self,
@@ -1206,25 +1199,6 @@ def prepare_inputs_for_generation(
         position_ids = context.position_ids
         attn_metadata = context.attn_metadata
 
-        # twobatch or onebatch
-        if self.enable_microbatch:
-            batch_size = attn_metadata.q_start_loc.size(dim=0)
-            tokens = input_ids.numel()
-            if attn_metadata.is_decoding:
-                enable_microbatch = batch_size >= self.enable_microbatch_decode_batchsize_threshold
-            else:
-                enable_microbatch = batch_size >= self.enable_microbatch_prefill_batchsize_threshold and \
-                    tokens >= self.enable_microbatch_prefill_token_threshold
-            if enable_microbatch:
-                disable_num = torch.tensor(0, dtype=torch.int32, device=input_ids.device)
-            else:
-                disable_num = torch.tensor(1, dtype=torch.int32, device=input_ids.device)
-            ep_group = get_dist_manager().current_context().ep_gpu_group
-            dist.all_reduce(disable_num, op=dist.ReduceOp.SUM, group=ep_group)
-            step_ctx = get_step_ctx_manager().current_context()
-            enable_microbatch = disable_num.item() == 0
-            step_ctx.enable_microbatch = enable_microbatch
-
         return dict(
             input_ids=input_ids,
             position_ids=position_ids,