diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py
index de1fba57c0..5ba3b74405 100644
--- a/serve/mlc_serve/engine/base.py
+++ b/serve/mlc_serve/engine/base.py
@@ -31,11 +31,14 @@ class MLCServeEngineConfig:
     # TODO(@sunggg): figure out better defaults
     use_staging_engine: bool = True
     max_num_batched_tokens: int = 4096
+    max_num_seq: int = 256
+    max_num_seq_per_request: Optional[int] = None  # defaults to `max_num_seq // 4`
     min_decode_steps: int = 32
     max_decode_steps: int = 48
     init_timeout: int = 120
     model_type: str = "tvm"  # "tvm", "torch"
     num_shards: Optional[int] = None  # Need to be specified for if model_type is "torch"
+    gpu_memory_utilization: float = 0.9

     @classmethod
     def _from_json(config_cls, json_obj: Dict[Any, Any]):
diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py
index a381d72f69..46d8d0c4e4 100644
--- a/serve/mlc_serve/engine/engine_common.py
+++ b/serve/mlc_serve/engine/engine_common.py
@@ -405,6 +405,8 @@ class EngineBase:
     model_artifact_config: ModelArtifactConfig
     max_context_length: int
     max_num_batched_tokens: int
+    max_num_seq: int
+    max_num_seq_per_request: int
     max_decode_steps: int
     min_decode_steps: int
     kv_cache_size: int
@@ -426,6 +428,10 @@ def __init__(self, model_module: ModelModule):
         ), "max_context_length must not be zero"
         self.max_context_length = self.model_artifact_config.max_context_length
         self.max_num_batched_tokens = model_module.engine_config.max_num_batched_tokens
+        self.max_num_seq = model_module.engine_config.max_num_seq
+        self.max_num_seq_per_request = model_module.engine_config.max_num_seq_per_request
+        if self.max_num_seq_per_request is None:
+            self.max_num_seq_per_request = self.max_num_seq // 4
         self.max_decode_steps = min(
             self.cache_manager.get_kv_cache_size(),
             model_module.engine_config.max_decode_steps,
@@ -592,6 +598,14 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]:
             )
             return None

+        current_num_seq = sum(len(s.generation_sequences) for s in self.current_batch.values())
+        if current_num_seq + len(state.generation_sequences) > self.max_num_seq:
+            LOG.debug(
+                "Stop growing the batch due to max number of sequences.",
+            )
+            return None
+
+
         self.queue.popleft()
         self.cache_manager.allocate(state.request_id, num_tokens, state.num_sequences)
         self.current_batch[state.request_id] = state
diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py
index 84e7ce8b28..5d402536d9 100644
--- a/serve/mlc_serve/engine/staging_engine_worker.py
+++ b/serve/mlc_serve/engine/staging_engine_worker.py
@@ -113,6 +113,12 @@ def add(self, request_states: list[RequestState]):
                     "The prompt is too long for the given set of engine"
                     " parameters."
                )
+            elif state.num_sequences > self.max_num_seq_per_request:
+                self.cancelled_requests.append(state)
+                state.validation_err = ValidationError(
+                    f"The number of sequences ({state.num_sequences}) is greater"
+                    f" than the maximum allowed value ({self.max_num_seq_per_request})"
+                )
             else:
                 valid_states.append(state)

diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py
index d38edb88cd..2941591552 100644
--- a/serve/mlc_serve/model/model_common.py
+++ b/serve/mlc_serve/model/model_common.py
@@ -35,7 +35,7 @@ def get_num_cache_blocks(
     num_layers,
     num_kv_heads,
     head_size,
-    gpu_memory_utilization=0.9,  # the default used by vllm
+    gpu_memory_utilization,
 ):
     cache_block_size = CacheManager.get_cache_block_size(
         block_size, num_layers, num_kv_heads, head_size
diff --git a/serve/mlc_serve/model/torch_model.py b/serve/mlc_serve/model/torch_model.py
index 04381cd3b4..91d4d74d35 100644
--- a/serve/mlc_serve/model/torch_model.py
+++ b/serve/mlc_serve/model/torch_model.py
@@ -168,6 +168,8 @@ def profile_and_init_cache(
     hf_config,
     num_shards,
     max_num_batched_tokens,
+    max_num_seq,
+    gpu_memory_utilization,
 ):
     num_kv_heads = hf_config.num_key_value_heads // num_shards
     num_hidden_layers = hf_config.num_hidden_layers
@@ -177,7 +179,9 @@ def profile_and_init_cache(

     if max_num_batched_tokens > 0:
         LOG.info("Running memory profiling.")
-        seq_lens = [1] * max_num_batched_tokens
+        seq_len = max_num_batched_tokens // max_num_seq
+        seq_lens = [seq_len] * max_num_seq
+        seq_lens[-1] += max_num_batched_tokens % max_num_seq
         used_memory_bytes = profile_memory_usage(
             pt_model, seq_lens, num_hidden_layers, hf_config.vocab_size
         )
@@ -187,6 +191,7 @@ def profile_and_init_cache(
             hf_config.num_hidden_layers,
             num_kv_heads,
             head_size,
+            gpu_memory_utilization,
         )
     else:
         num_blocks = 500
@@ -423,6 +428,8 @@ def exposed_init_model(
             hf_config,
             num_shards,
             engine_config.max_num_batched_tokens,
+            engine_config.max_num_seq,
+            engine_config.gpu_memory_utilization,
         )

         return num_blocks
@@ -593,6 +600,8 @@ def __init__(
             hf_config,
             1,
             engine_config.max_num_batched_tokens,
+            engine_config.max_num_seq,
+            engine_config.gpu_memory_utilization,
         )

         self.model_rpc = None
diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py
index baababe368..4b0d9ff228 100644
--- a/serve/mlc_serve/model/tvm_model.py
+++ b/serve/mlc_serve/model/tvm_model.py
@@ -588,7 +588,12 @@ def init_tvm_model(
     if engine_config.max_num_batched_tokens > 0:
         LOG.info("Running memory profiling.")
         try:
-            seq_lens = [1] * engine_config.max_num_batched_tokens
+            max_num_seq = engine_config.max_num_seq
+            max_num_batched_tokens = engine_config.max_num_batched_tokens
+            seq_len = max_num_batched_tokens // max_num_seq
+            seq_lens = [seq_len] * max_num_seq
+            seq_lens[-1] += max_num_batched_tokens % max_num_seq
+
             used_memory_bytes = model.profile_memory_usage(seq_lens)
             num_blocks = get_num_cache_blocks(
                 used_memory_bytes,
@@ -596,6 +601,7 @@ def init_tvm_model(
                 model_artifact_config.num_hidden_layers,
                 num_kv_heads,
                 head_size,
+                engine_config.gpu_memory_utilization,
             )
         except tvm.error.InternalError:
             raise RuntimeError(
diff --git a/serve/mlc_serve/utils.py b/serve/mlc_serve/utils.py
index 578f5893b3..e1d9ca80aa 100644
--- a/serve/mlc_serve/utils.py
+++ b/serve/mlc_serve/utils.py
@@ -30,8 +30,10 @@ def get_default_mlc_serve_argparser(description="", allow_override=False):
     parser.add_argument("--use-sync-engine", action="store_true")
     parser.add_argument("--num-sequences-to-sample", type=int, default=1)
parser.add_argument("--max-num-batched-tokens", type=int, default=4096) + parser.add_argument("--max-num-seq", type=int, default=256) parser.add_argument("--min-decode-steps", type=int, default=32) parser.add_argument("--max-decode-steps", type=int, default=56) + parser.add_argument("--gpu-memory-utilization", type=float, default=0.9) parser.add_argument("--debug-logging", action="store_true") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--num-shards", type=int, default=1) # Needed for PT models @@ -73,10 +75,12 @@ def create_mlc_engine(args: argparse.Namespace, start_engine=True) -> InferenceE { "use_staging_engine": args.use_staging_engine, "max_num_batched_tokens": args.max_num_batched_tokens, + "max_num_seq": args.max_num_seq, "min_decode_steps": args.min_decode_steps, "max_decode_steps": args.max_decode_steps, "model_type": model_type, "num_shards": num_shards, + "gpu_memory_utilization": args.gpu_memory_utilization, } )