diff --git a/serve/benchmarks/benchmark_latency.py b/serve/benchmarks/benchmark_latency.py
index 84c377b710..0c42606691 100644
--- a/serve/benchmarks/benchmark_latency.py
+++ b/serve/benchmarks/benchmark_latency.py
@@ -34,8 +34,9 @@ def create_request(request_id):
             frequency_penalty=args.sampling_setting["frequency_penalty"],
             presence_penalty=args.sampling_setting["presence_penalty"],
             logit_bias=args.sampling_setting["logit_bias"],
-            logprobs = args.sampling_setting["logprobs"],
-            top_logprobs = args.sampling_setting["top_logprobs"],
+            logprobs=args.sampling_setting["logprobs"],
+            top_logprobs=args.sampling_setting["top_logprobs"],
+            json_schema=args.sampling_setting["json_schema"],
         ),
         stopping_criteria=StoppingCriteria(
             max_tokens=args.num_output_tokens, stop_sequences=None
diff --git a/serve/benchmarks/benchmark_throughput.py b/serve/benchmarks/benchmark_throughput.py
index 4f82d80dd7..3cb6958e47 100644
--- a/serve/benchmarks/benchmark_throughput.py
+++ b/serve/benchmarks/benchmark_throughput.py
@@ -139,8 +139,9 @@ def run_mlc(engine, requests, args) -> float:
                 frequency_penalty=args.sampling_setting["frequency_penalty"],
                 presence_penalty=args.sampling_setting["presence_penalty"],
                 logit_bias=args.sampling_setting["logit_bias"],
-                logprobs = args.sampling_setting["logprobs"],
-                top_logprobs = args.sampling_setting["top_logprobs"],
+                logprobs=args.sampling_setting["logprobs"],
+                top_logprobs=args.sampling_setting["top_logprobs"],
+                json_schema=args.sampling_setting["json_schema"],
             ),
             stopping_criteria=StoppingCriteria(
                 max_tokens=args.num_output_tokens, stop_sequences=None
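# Example (not part of the patch): a minimal sketch of what the two benchmark
# scripts above now forward as `json_schema`. The value is produced by
# postproc_sampling_args in serve/benchmarks/utils.py (changed below) when
# --apply-json-mode is passed; `Output` mirrors the model added there.
from pydantic import BaseModel


class Output(BaseModel):
    answer: str


sampling_setting = {"json_schema": Output.model_json_schema()}
print(sampling_setting["json_schema"])  # a plain dict, later carried by SamplingParams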
diff --git a/serve/benchmarks/utils.py b/serve/benchmarks/utils.py
index 4507dc0dd2..3c893c660a 100644
--- a/serve/benchmarks/utils.py
+++ b/serve/benchmarks/utils.py
@@ -1,4 +1,9 @@
 """Utils for benchmark scripts"""
+from pydantic import BaseModel
+
+
+class Output(BaseModel):
+    answer: str
 
 
 def add_sampling_flags(parser):
@@ -17,6 +22,11 @@ def add_sampling_flags(parser):
         action="store_true",
         help="Apply top-p and top-k.",
     )
+    parser.add_argument(
+        "--apply-json-mode",
+        action="store_true",
+        help="Apply json mode.",
+    )
     parser.add_argument(
         "--apply-all-sampling-params",
         action="store_true",
@@ -26,13 +36,13 @@ def add_sampling_flags(parser):
         "--logprobs",
         action="store_true",
         default=False,
-        help="Switch on logprobs output"
+        help="Switch on logprobs output",
     )
     parser.add_argument(
         "--top-logprobs",
         type=int,
         default=5,
-        help="Number of top logprobs to output, limited by 5. Works only with logprobs true."
+        help="Number of top logprobs to output, limited by 5. Works only with logprobs true.",
     )
 
 
@@ -47,12 +57,14 @@ def postproc_sampling_args(args):
         "top_k": -1,
         "logprobs": False,
         "top_logprobs": 5,
+        "json_schema": None,
     }
 
     if args.apply_all_sampling_params:
         args.apply_penalties = True
         args.apply_logit_bias = True
         args.apply_top_p_top_k = True
+        args.apply_json_mode = True
 
     if args.apply_penalties:
         args.sampling_setting["presence_penalty"] = 0.7
@@ -69,3 +81,6 @@ def postproc_sampling_args(args):
     if args.logprobs:
         args.sampling_setting["logprobs"] = True
         args.sampling_setting["top_logprobs"] = args.top_logprobs
+
+    if args.apply_json_mode:
+        args.sampling_setting["json_schema"] = Output.model_json_schema()
diff --git a/serve/mlc_serve/api/handler.py b/serve/mlc_serve/api/handler.py
index ae4f40a32b..1ac19de062 100644
--- a/serve/mlc_serve/api/handler.py
+++ b/serve/mlc_serve/api/handler.py
@@ -71,7 +71,8 @@ def _get_sampling_params(
     if request.logprobs:
         sampling_params.top_logprobs = request.top_logprobs
         sampling_params.logprobs = request.logprobs
-
+    if request.response_format and request.response_format.type == "json_object":
+        sampling_params.json_schema = request.response_format.response_schema
     sampling_params.vocab_size = model_artifact_config.vocab_size
     return sampling_params
 
diff --git a/serve/mlc_serve/api/protocol.py b/serve/mlc_serve/api/protocol.py
index 4f42f7233e..54f687ab73 100644
--- a/serve/mlc_serve/api/protocol.py
+++ b/serve/mlc_serve/api/protocol.py
@@ -2,7 +2,7 @@
 # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
 # https://github.com/vllm-project/vllm/blob/acbed3ef40f015fcf64460e629813922fab90380/vllm/entrypoints/openai/protocol.py
 import time
-from typing import Dict, List, Literal, Optional, Union
+from typing import Dict, List, Literal, Optional, Union, Any
 
 from pydantic import BaseModel, Field
 
@@ -58,9 +58,16 @@ class ChatMessage(BaseModel):
     content: str
 
 
+class ChatResponseFormat(BaseModel):
+    type: str
+    response_schema: Optional[Dict[str, Any]] = Field(None, alias="schema")
+
+
 class ChatCompletionRequest(BaseModel):
     model: str
-    messages: Union[str, List[ChatMessage]]  # according to openai chat completion spec, here should be only a list of ChatMessage
+    messages: Union[
+        str, List[ChatMessage]
+    ]  # according to the OpenAI chat completion spec, this should be only a list of ChatMessage
     max_tokens: Optional[int] = None
     temperature: float = 1.0
     top_p: float = 1.0
@@ -75,6 +82,7 @@ class ChatCompletionRequest(BaseModel):
     ignore_eos: Optional[bool] = False
     logprobs: bool = False
     top_logprobs: int = 0
+    response_format: Optional[ChatResponseFormat] = None
 
 
 class ChatCompletionResponseChoice(BaseModel):
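# Example (not part of the patch): how the new response_format field is parsed
# and consumed. The request body mirrors api/protocol.py; the handler change
# above copies the schema onto SamplingParams when type == "json_object".
from mlc_serve.api.protocol import ChatCompletionRequest

req = ChatCompletionRequest.model_validate(
    {
        "model": "any-model",  # placeholder
        "messages": [{"role": "user", "content": "What is the capital of France?"}],
        "response_format": {
            "type": "json_object",
            "schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
        },
    }
)
# The wire key "schema" populates `response_schema` through the Field alias.
assert req.response_format is not None and req.response_format.type == "json_object"
print(req.response_format.response_schema)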
diff --git a/serve/mlc_serve/engine/constrained_sampling.py b/serve/mlc_serve/engine/constrained_sampling.py
new file mode 100644
index 0000000000..9e0f3e6270
--- /dev/null
+++ b/serve/mlc_serve/engine/constrained_sampling.py
@@ -0,0 +1,94 @@
+import json
+import math
+from collections import defaultdict
+from typing import DefaultDict, List
+
+import torch
+
+from outlines.fsm.fsm import RegexFSM
+from outlines.fsm.json_schema import build_regex_from_object
+from .base import SequenceId
+
+class RegexLogitsProcessor:
+    def __init__(self, regex_string, tokenizer):
+        """Compile the FSM that drives the regex-guided generation.
+
+        Parameters
+        ----------
+        regex_string
+            A string that represents a regular expression
+        tokenizer
+            An instance of `tokenizer`
+
+        """
+        tokenizer = self.adapt_tokenizer(tokenizer)
+
+        fsm = RegexFSM(regex_string, tokenizer)
+        self.fsm = fsm
+        self.fsm_state: DefaultDict[SequenceId, int] = defaultdict(int)
+
+    def __call__(
+        self, seq_id: SequenceId, input_ids: List[int], scores: torch.Tensor
+    ) -> torch.Tensor:
+        """Use the FSM to bias the logits before sampling the next token."""
+
+        if len(input_ids) == 0:  # Initialize the fsm states
+            self.fsm_state = defaultdict(int)
+        else:
+            last_token = input_ids[-1]
+            self.fsm_state[seq_id] = self.fsm.next_state(
+                self.fsm_state[seq_id], last_token
+            )
+
+        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
+
+        mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device)
+        mask[allowed_tokens] = 0
+        biased_scores = scores + mask
+
+        return biased_scores
+
+    def adapt_tokenizer(self, tokenizer):
+        """Adapt the HF tokenizer so that it can be used to compile the FSM.
+
+        The API of Outlines tokenizers is slightly different from that of
+        `transformers`. In addition, we need to handle the missing spaces in
+        Llama's tokenizer output to be able to compile FSMs for this model.
+
+        """
+        tokenizer.vocabulary = tokenizer.get_vocab()
+        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+        def convert_token_to_string(token: str) -> str:
+            from transformers.file_utils import SPIECE_UNDERLINE
+
+            string = tokenizer.convert_tokens_to_string([token])
+
+            # A hack to handle missing spaces in HF's Llama tokenizers
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+                return " " + string
+
+            return string
+
+        tokenizer.convert_token_to_string = convert_token_to_string
+
+        return tokenizer
+
+
+class JSONLogitsProcessor(RegexLogitsProcessor):
+    def __init__(self, schema, tokenizer):
+        """Compile the FSM that drives the JSON-guided generation.
+
+        Parameters
+        ----------
+        schema
+            A JSON schema that encodes the structure we want the model to generate
+        tokenizer
+            An instance of `tokenizer`
+
+        """
+        if isinstance(schema, dict):
+            schema = json.dumps(schema)
+        regex_string = build_regex_from_object(schema)
+        super().__init__(regex_string, tokenizer)
+
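# Example (not part of the patch): a standalone sketch of the JSONLogitsProcessor
# defined above. The tokenizer name is a placeholder; inside the engine the first
# argument to the processor is a SequenceId, but any hashable key works for this
# illustration, and behavior depends on the pinned outlines version working as
# the new module assumes.
import torch
from transformers import AutoTokenizer

from mlc_serve.engine.constrained_sampling import JSONLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")  # placeholder
schema = {"type": "object", "properties": {"answer": {"type": "string"}}, "required": ["answer"]}
processor = JSONLogitsProcessor(schema, tokenizer)

logits = torch.zeros(len(tokenizer.get_vocab()))
masked = processor("seq-0", [], logits)  # empty input_ids -> prefill step
print(int((masked != float("-inf")).sum()), "tokens are allowed to start the JSON object")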
diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py
index 96e17eeefc..acf1faa90d 100644
--- a/serve/mlc_serve/engine/engine_common.py
+++ b/serve/mlc_serve/engine/engine_common.py
@@ -32,6 +32,7 @@
 )
 from ..model.base import ModelArtifactConfig
 from ..openai_logprob_protocol import LogprobsContent, TopLogprobs
+from .constrained_sampling import JSONLogitsProcessor
 
 LOG = structlog.stdlib.get_logger(__name__)
 
@@ -240,7 +241,9 @@ def prepare_output(
 
 
 def get_requests_to_process(
-    current_states: list[RequestState], cache_manager: KVCacheManager
+    current_states: list[RequestState],
+    cache_manager: KVCacheManager,
+    tokenizer: TokenizerP,
 ) -> Tuple[list[RequestType], bool, int]:
     requests: list[RequestType] = []
     # TODO: consider having hybrid batch if the underlying attention kernel supports
@@ -289,6 +292,12 @@ def get_requests_to_process(
             # TODO(masahi): How to account for token counts in EvalMultiQueryRequest in
             # Prometheus metric?
         elif not state.is_prefilled:
+            # The `JSONLogitsProcessor` needs to be created only once, when the
+            # request is first prefilled.
+            if state.sampling_params.json_schema is not None:
+                state.sampling_params.logits_processor = JSONLogitsProcessor(
+                    state.sampling_params.json_schema, tokenizer._tokenizer
+                )
             if (
                 state.num_sequences == 1
                 and state.generation_sequences[0].generated_token_ids
diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py
index 0e091e0c4a..867616d0d5 100644
--- a/serve/mlc_serve/engine/model_module.py
+++ b/serve/mlc_serve/engine/model_module.py
@@ -2,7 +2,7 @@
 Required interfaces for the actual inference capability in InferenceEngine.
 """
 from dataclasses import dataclass
-from typing import Optional, Protocol, Union, List, Sequence
+from typing import Optional, Protocol, Union, List, Sequence, Any
 
 from .base import (
     ChatMessage,
@@ -168,6 +168,7 @@ def generate(
 
 
 class Tokenizer(Protocol):
+    _tokenizer: Any
     eos_token_id: int
     skip_special_tokens: bool
     all_special_ids: List[int]
diff --git a/serve/mlc_serve/engine/sampling_params.py b/serve/mlc_serve/engine/sampling_params.py
index 6344b7ee15..ae96df7709 100644
--- a/serve/mlc_serve/engine/sampling_params.py
+++ b/serve/mlc_serve/engine/sampling_params.py
@@ -6,7 +6,7 @@
 from dataclasses import dataclass
 from enum import IntEnum
 from functools import cached_property
-from typing import Dict, Optional
+from typing import Dict, Optional, Any
 
 _SAMPLING_EPS = 1e-5
 LOGPROB_TOP_K_MAX = 5
@@ -73,6 +73,8 @@ class SamplingParams:
     # Currently, it is unclear what is the best way to fetch this info and
     # check in `_verify_args` without this field. Follow-up when we have a better idea.
     vocab_size = 32000
+    json_schema: Optional[Dict[str, Any]] = None
+    logits_processor: Optional[Any] = None
 
     def __post_init__(self):
         if self.logit_bias:
diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py
index c73c6160c2..ad5e3089aa 100644
--- a/serve/mlc_serve/engine/staging_engine_worker.py
+++ b/serve/mlc_serve/engine/staging_engine_worker.py
@@ -329,7 +329,7 @@ def _adjust_batch(self):
 
     def _get_requests_to_process(self):
         requests, is_prompt_batch, token_counts = get_requests_to_process(
-            self.current_batch.values(), self.cache_manager
+            self.current_batch.values(), self.cache_manager, self.tokenizer
         )
 
         if is_prompt_batch:
diff --git a/serve/mlc_serve/engine/sync_engine.py b/serve/mlc_serve/engine/sync_engine.py
index 0baa8b50bf..a0e7194b22 100644
--- a/serve/mlc_serve/engine/sync_engine.py
+++ b/serve/mlc_serve/engine/sync_engine.py
@@ -143,7 +143,7 @@ def step(self) -> InferenceStepResult:
             return InferenceStepResult(outputs)
 
         requests, _, _ = get_requests_to_process(
-            list(self.current_batch.values()), self.cache_manager
+            list(self.current_batch.values()), self.cache_manager, self.tokenizer
        )
         results = self.text_generator.generate(requests, self.cache_manager.get_cache())
         logger.debug("Finished text generation.")
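# Example (not part of the patch): how the pieces above fit together. A request
# only carries a JSON schema on its SamplingParams; get_requests_to_process in
# engine_common.py later attaches a JSONLogitsProcessor built from it, which is
# why the tokenizer is now threaded through from both engines. Defaults are
# assumed for every field not shown.
from mlc_serve.engine.sampling_params import SamplingParams

params = SamplingParams(
    json_schema={"type": "object", "properties": {"answer": {"type": "string"}}},
)
assert params.logits_processor is None  # filled in at prefill time, not by the caller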
diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py
index 1dda0cee94..d5a010e4e5 100644
--- a/serve/mlc_serve/model/model_common.py
+++ b/serve/mlc_serve/model/model_common.py
@@ -14,6 +14,7 @@
 )
 from ..engine.model_module import (
     PrefillRequest,
+    DecodeRequest,
     EvalMultiQueryRequest,
     RequestType,
     TextGenerationResult,
@@ -97,6 +98,17 @@ def sample_from_logits(
     # synchronization point for sampling tensors
     # wait until all the tensors are loaded on GPU
     torch.cuda.current_stream().wait_stream(copy_stream)
+
+    # Logit processing for constrained sampling, e.g., JSON mode
+    for i, (sequence_id, request) in enumerate(zip(sequence_ids, requests)):
+        if request.sampling_params.logits_processor is not None:
+            cs_input_ids = (
+                request.token_ids if isinstance(request, DecodeRequest) else []
+            )
+            logits[i] = request.sampling_params.logits_processor(
+                sequence_id, cs_input_ids, logits[i]
+            )
+
     logits = adjust_logits(logits, sampling_metadata, vocab_size)
 
     outputs: List[TextGenerationResult] = []
diff --git a/serve/mlc_serve/model/tokenizer.py b/serve/mlc_serve/model/tokenizer.py
index 23e0c044a5..68bb978110 100644
--- a/serve/mlc_serve/model/tokenizer.py
+++ b/serve/mlc_serve/model/tokenizer.py
@@ -49,6 +49,7 @@ class HfTokenizerModule:
     def __init__(self, model_artifact_path: Path):
         hf_tokenizer = AutoTokenizer.from_pretrained(
             model_artifact_path.joinpath("model"),
+            revision=None,
             tokenizer_revision=None,
             trust_remote_code=False,
         )
         self.tokenizer = Tokenizer(hf_tokenizer)
diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py
index c5970ace6e..a69e98118e 100644
--- a/serve/mlc_serve/model/tvm_model.py
+++ b/serve/mlc_serve/model/tvm_model.py
@@ -504,7 +504,7 @@ def init_tvm_model(
     except tvm.error.InternalError:
         raise RuntimeError(
             f"Memory profiling failed with max_num_batched_tokens = "
-                "{engine_config.max_num_batched_tokens}."
+            "{engine_config.max_num_batched_tokens}."
         )
     else:
         num_blocks = 500
diff --git a/serve/pyproject.toml b/serve/pyproject.toml
index 4152805c5a..f79af292c1 100644
--- a/serve/pyproject.toml
+++ b/serve/pyproject.toml
@@ -9,11 +9,14 @@ python = ">=3.9"
 fastapi = ">=0.103.1"
 pydantic = ">=1.8.0"
 prometheus-client = ">=0.18.0"
+outlines = "0.0.23"
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^7.4.2"
 httpx_sse = "^0.3.1"
 pytest-timeout = "^2.2.0"
+cuda-python = "12.3.0"
+pandas = "2.2.0"
 
 [tool.setuptools]
 packages = ["mlc_serve"]
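# Example (not part of the patch): the effect of the per-sequence mask that
# sample_from_logits applies above. Disallowed tokens end up with probability
# zero after softmax; allowed ones keep their relative weights. The values here
# are illustrative.
import math

import torch

logits = torch.tensor([1.0, 2.0, 0.5, 3.0])
allowed_tokens = [1, 3]  # token ids the FSM permits next

mask = torch.full((logits.shape[-1],), -math.inf)
mask[allowed_tokens] = 0.0
probs = torch.softmax(logits + mask, dim=-1)

print(probs)  # zeros everywhere except indices 1 and 3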
prompt="Hello", + temp=0, + freq_pen=0, + pre_pen=0, + max_tokens=30, + stop=None, + ignore_eos=False, + ), + # test Snow schema + create_request( + idx=str(2), + prompt="what is the color of the snow?", + temp=0, + freq_pen=0, + pre_pen=0, + max_tokens=30, + stop=None, + ignore_eos=False, + json_schema=Snow.model_json_schema(), + ), + # test SnowList schema (nested structure) + create_request( + idx=str(3), + prompt="Quick Facts About Snow | National Snow and Ice Data Center When light reflects off it, snow appears white. The many sides of a snowflake scatter light, diffusing the color spectrum in many directions. Snow can look dark when dust, or pollution, cover it. Fresh-water algae that loves snow can turn it into other colors like orange, blue, or watermelon pink. List the colors of snow.", + temp=0, + freq_pen=0, + pre_pen=0, + max_tokens=256, + stop=None, + ignore_eos=False, + json_schema=SnowList.model_json_schema(), + ), + ] + num_requests = len(requests) + engine.add(requests) + + generated = ["" for _ in range(num_requests)] + + while engine.has_pending_requests(): + results = engine.step() + for res in results.outputs: + assert len(res.sequences) == 1 + seq = res.sequences[0] + + if not seq.is_finished: + generated[int(res.request_id)] += seq.delta + + for i, out_text in enumerate(generated): + if i == 0: + France.model_validate(json.loads(out_text)) + elif i == 1: + assert isinstance(out_text, str) + elif i == 2: + Snow.model_validate(json.loads(out_text)) + else: + SnowList.model_validate(json.loads(out_text)) + + if __name__ == "__main__": parser = get_default_mlc_serve_argparser("test engine with samplers") args = parser.parse_args() @@ -350,6 +445,7 @@ def _test_logprobs_mixed_requests( # _test_stop(staging_engine) _test_logprobs(staging_engine) _test_logprobs_mixed_requests(staging_engine) + _test_json_mode(staging_engine) # These tests are broken since we are now imposing no length limit # if max_tokens = None. The tests do not finish in a reasonable time. # _test_max_context_length(staging_engine) @@ -364,6 +460,7 @@ def _test_logprobs_mixed_requests( _test_stop(sync_engine) _test_logprobs(sync_engine) _test_logprobs_mixed_requests(sync_engine) + _test_json_mode(sync_engine) # These tests are broken since we are now imposing no length limit # if max_tokens = None. The tests do not finish in a reasonable time. # _test_max_context_length(sync_engine)