
update twomicrobatch #3651

Open · wants to merge 3 commits into base: main
Changes from all commits
5 changes: 3 additions & 2 deletions lmdeploy/pytorch/backends/cuda/graph_runner.py
@@ -6,7 +6,7 @@

 from lmdeploy.pytorch.backends.selector import get_backend
 from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
-from lmdeploy.pytorch.model_inputs import StepContext
+from lmdeploy.pytorch.model_inputs import StepContext, get_step_ctx_manager
 from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta
 from lmdeploy.utils import get_logger

@@ -150,11 +150,12 @@ def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, pas
         is_decoding = context.is_decoding
         num_tokens = input_ids.numel()
         meta = self.get_meta()
+        enable_microbatch = get_step_ctx_manager().current_context().enable_microbatch
         if meta.padding_batch_size is None:
             new_num_tokens = next_power_of_2(num_tokens)
         else:
             new_num_tokens = next_power_of_2(meta.padding_batch_size)
-        return (new_num_tokens, is_decoding)
+        return (new_num_tokens, is_decoding, enable_microbatch)

     def __call__(self, **kwargs):
         """call."""
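The runner change above folds the microbatch flag into the CUDA-graph dispatch key, so a graph captured with microbatching enabled is never replayed for a batch that runs without it (and vice versa). A minimal sketch of the idea, with illustrative names rather than lmdeploy's actual runner classes:

```python
# Minimal sketch (not lmdeploy's runner): a graph cache keyed the same way the
# patched get_graph_key builds its tuple. All names here are illustrative.

def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n (returns 1 for n <= 1)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

class GraphCache:
    def __init__(self):
        # (padded_tokens, is_decoding, enable_microbatch) -> captured graph
        self._graphs = {}

    def key(self, num_tokens: int, is_decoding: bool, enable_microbatch: bool):
        # Padding to a power of two bounds the number of captured graphs;
        # including enable_microbatch separates the two execution paths.
        return (next_power_of_2(num_tokens), is_decoding, enable_microbatch)

    def get_or_capture(self, num_tokens, is_decoding, enable_microbatch, capture_fn):
        k = self.key(num_tokens, is_decoding, enable_microbatch)
        if k not in self._graphs:
            self._graphs[k] = capture_fn()  # capture a new graph for this shape/mode
        return self._graphs[k]
```

Adding the boolean at most doubles the number of cached graphs, while keeping the microbatched and single-batch replay paths strictly separate.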
25 changes: 25 additions & 0 deletions lmdeploy/pytorch/engine/model_agent.py
@@ -4,6 +4,7 @@
 import functools
 from contextlib import asynccontextmanager, contextmanager
 from multiprocessing.reduction import ForkingPickler
+from os import getenv
 from typing import Any, Dict

 import torch
@@ -274,6 +275,15 @@ def __init__(self,
         self.cache_engine = None
         self.profiler: AgentProfiler = None

+        # microbatch
+        self.enable_microbatch = self.dist_ctx.dist_config.enable_microbatch
+        self.enable_microbatch_prefill_batchsize_threshold = \
+            int(getenv('ENABLE_MICROBATCH_PREFILL_BATCHSIZE_THRESHOLD', 2))
+        self.enable_microbatch_prefill_token_threshold = \
+            int(getenv('ENABLE_MICROBATCH_PREFILL_TOKEN_THRESHOLD', 2))
+        self.enable_microbatch_decode_batchsize_threshold = \
+            int(getenv('ENABLE_MICROBATCH_DECODE_BATCHSIZE_THRESHOLD', 2))
+
     @contextmanager
     def all_context(self):
         device_mgr = get_device_manager()
@@ -517,6 +527,17 @@ async def __prepare_dp():
             # gather dp forward metadata
             batch_size = inputs.seq_length.numel()
             dp_forward_meta = [int(is_decoding), int(is_dummy), batch_size, int(sync_long_context)]
+            # check enable_microbatch
+            if self.enable_microbatch:
+                tokens_num = inputs.input_ids.numel()
+                if is_decoding:
+                    enable_microbatch = batch_size >= \
+                        self.enable_microbatch_decode_batchsize_threshold
+                else:
+                    enable_microbatch = batch_size >= \
+                        self.enable_microbatch_prefill_batchsize_threshold and \
+                        tokens_num >= self.enable_microbatch_prefill_token_threshold
+                dp_forward_meta.append(int(enable_microbatch))
             gathered_meta = DistGatherScalar(dp_forward_meta, dp, device='cuda')

             yield
@@ -543,6 +564,10 @@ async def __prepare_dp():
             sync_long_context = all_sync_flags.any()
             logger.debug(f'sync_long_context={sync_long_context}')

+            # update if enable_microbatch
+            if self.enable_microbatch and gathered_meta[:, 4].all():
+                inputs.enable_microbatch = True
+
             # update dp meta
             inputs.build_dp_meta()
             inputs = self.patched_model.update_inputs(inputs)
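Taken together, the model_agent.py changes read the three thresholds from environment variables once at construction, let each DP rank vote on whether its local batch is large enough, and only set inputs.enable_microbatch when every rank's vote (column 4 of the gathered metadata) is true. A small sketch of that decision rule, using the defaults shown in the diff and plain Python in place of the engine's DistGatherScalar:

```python
# Sketch of the per-rank decision plus the all-rank agreement check.
# This mirrors the logic in the diff, not the engine API.
from os import getenv

PREFILL_BS = int(getenv('ENABLE_MICROBATCH_PREFILL_BATCHSIZE_THRESHOLD', 2))
PREFILL_TOK = int(getenv('ENABLE_MICROBATCH_PREFILL_TOKEN_THRESHOLD', 2))
DECODE_BS = int(getenv('ENABLE_MICROBATCH_DECODE_BATCHSIZE_THRESHOLD', 2))

def local_wants_microbatch(is_decoding: bool, batch_size: int, tokens_num: int) -> bool:
    # Decode steps only look at batch size; prefill also requires enough tokens.
    if is_decoding:
        return batch_size >= DECODE_BS
    return batch_size >= PREFILL_BS and tokens_num >= PREFILL_TOK

def all_ranks_agree(per_rank_flags: list) -> bool:
    # The engine gathers one flag per DP rank (the 5th entry of dp_forward_meta)
    # and enables microbatching only if every rank opted in.
    return all(per_rank_flags)

# e.g. a decode step with batch sizes [4, 1] across two ranks stays single-batch:
assert not all_ranks_agree([local_wants_microbatch(True, 4, 4),
                            local_wants_microbatch(True, 1, 1)])
```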
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/model_inputs.py
@@ -152,6 +152,7 @@ class ModelInputs:
     history_cross_length: torch.LongTensor = None
     model_metas: List[Dict[str, Any]] = None
     dp_meta: 'DPMeta' = None
+    enable_microbatch: bool = False

     def update(self, input_ids: torch.LongTensor):
         """Update input ids."""
@@ -420,6 +421,7 @@ def new(
             cross_seqlens=cross_seqlens,
             cross_kv_seqlens=cross_kv_seqlens,
             dp_meta=inputs.dp_meta,
+            enable_microbatch=inputs.enable_microbatch,
         )

         ret = get_backend().update_step_context(ret)
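These two lines are plumbing: ModelInputs gains a default-off enable_microbatch field, and StepContext.new copies it into the per-step context, which is what the graph runner later reads back through get_step_ctx_manager(). A simplified stand-in showing the propagation (FakeModelInputs and FakeStepContext are illustrative, not the real classes):

```python
# Simplified stand-ins to show how the flag travels from engine inputs to the
# step context; these are not lmdeploy's actual dataclasses.
from dataclasses import dataclass
from typing import List

@dataclass
class FakeModelInputs:
    input_ids: List[int]
    enable_microbatch: bool = False  # defaults off, as in the diff

@dataclass
class FakeStepContext:
    is_decoding: bool
    enable_microbatch: bool

    @classmethod
    def new(cls, inputs: FakeModelInputs, is_decoding: bool) -> 'FakeStepContext':
        # StepContext.new now forwards enable_microbatch from the inputs.
        return cls(is_decoding=is_decoding, enable_microbatch=inputs.enable_microbatch)

ctx = FakeStepContext.new(FakeModelInputs(input_ids=[1, 2, 3], enable_microbatch=True),
                          is_decoding=False)
assert ctx.enable_microbatch
```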
26 changes: 0 additions & 26 deletions lmdeploy/pytorch/models/deepseek_v2.py
@@ -1152,13 +1152,6 @@ def __init__(self,
                                             dtype=dtype,
                                             device=device)
         self._load_buffers = dict()
-        self.enable_microbatch = get_dist_manager().current_context().dist_config.enable_microbatch
-        self.enable_microbatch_prefill_batchsize_threshold = \
-            int(getenv('ENABLE_MICROBATCH_PREFILL_BATCHSIZE_THRESHOLD', 2))
-        self.enable_microbatch_prefill_token_threshold = \
-            int(getenv('ENABLE_MICROBATCH_PREFILL_TOKEN_THRESHOLD', 2))
-        self.enable_microbatch_decode_batchsize_threshold = \
-            int(getenv('ENABLE_MICROBATCH_DECODE_BATCHSIZE_THRESHOLD', 2))

     def forward(
         self,
@@ -1206,25 +1199,6 @@ def prepare_inputs_for_generation(
         position_ids = context.position_ids
         attn_metadata = context.attn_metadata

-        # twobatch or onebatch
-        if self.enable_microbatch:
-            batch_size = attn_metadata.q_start_loc.size(dim=0)
-            tokens = input_ids.numel()
-            if attn_metadata.is_decoding:
-                enable_microbatch = batch_size >= self.enable_microbatch_decode_batchsize_threshold
-            else:
-                enable_microbatch = batch_size >= self.enable_microbatch_prefill_batchsize_threshold and \
-                    tokens >= self.enable_microbatch_prefill_token_threshold
-            if enable_microbatch:
-                disable_num = torch.tensor(0, dtype=torch.int32, device=input_ids.device)
-            else:
-                disable_num = torch.tensor(1, dtype=torch.int32, device=input_ids.device)
-            ep_group = get_dist_manager().current_context().ep_gpu_group
-            dist.all_reduce(disable_num, op=dist.ReduceOp.SUM, group=ep_group)
-            step_ctx = get_step_ctx_manager().current_context()
-            enable_microbatch = disable_num.item() == 0
-            step_ctx.enable_microbatch = enable_microbatch
-
         return dict(
             input_ids=input_ids,
             position_ids=position_ids,
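The deleted block decided microbatching inside the model on every forward pass, by all-reducing a per-rank disable counter over the EP group and enabling only when the sum was zero. The new engine-side check (gathered_meta[:, 4].all()) expresses the same "all ranks agree" rule once per step. A dependency-free sketch of the equivalence:

```python
# The removed in-model consensus summed per-rank "disable" votes and enabled
# microbatching only when the sum was zero; the new engine-side check requires
# every gathered per-rank flag to be truthy. Both are the same rule, shown here
# without any distributed setup.

def old_rule(enable_flags):
    disable_votes = [0 if f else 1 for f in enable_flags]  # per-rank disable_num
    return sum(disable_votes) == 0                          # all_reduce(SUM) == 0

def new_rule(enable_flags):
    return all(enable_flags)                                # gathered_meta[:, 4].all()

for flags in ([True, True], [True, False], [False, False]):
    assert old_rule(flags) == new_rule(flags)
```

Moving the decision out of deepseek_v2.py also keeps it off the per-layer forward path and makes it available to any model, not just DeepSeek-V2.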