Commit de59d58

feat: replace outlines with xgrammar in pytorch engine
1 parent c45deea commit de59d58

13 files changed: +83 -179 lines

.github/workflows/unit-test.yml

Lines changed: 0 additions & 1 deletion
@@ -64,7 +64,6 @@ jobs:
           python3 -m pip install /root/packages/cu118/flash_attn-*.whl
           python3 -m pip install -r requirements_cuda.txt -r requirements/test.txt
           python3 -m pip install -e .
-          python3 -m pip install -U 'numpy<2.0'
       - name: Check env
         run: |
           python3 -m pip list

docker/prepare_wheel.sh

Lines changed: 0 additions & 1 deletion
@@ -17,7 +17,6 @@ if [[ ${PYTHON_VERSION} = "3.13" ]]; then
 
     pip install setuptools_rust
     pip wheel -v --no-build-isolation --no-deps -w /wheels "git+https://github.com/google/[email protected]#subdirectory=python"
-    pip wheel -v --no-build-isolation --no-deps -w /wheels --use-deprecated=legacy-resolver outlines_core==0.1.26
 fi
 
 if [[ "${CUDA_VERSION_SHORT}" != "cu118" ]]; then
lmdeploy/pytorch/engine/guided_process.py

Lines changed: 52 additions & 126 deletions
@@ -1,161 +1,87 @@
-# Copyright 2024- the Outlines developers
-# This file is adapted from
-# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-
-#     http://www.apache.org/licenses/LICENSE-2.0
-
+# Copyright (c) OpenMMLab. All rights reserved.
 import copy
-import math
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from collections import defaultdict
+import json
+import logging
 from functools import lru_cache
-from typing import DefaultDict, Dict, List, Union
+from typing import Optional
 
 import torch
-from outlines.fsm.guide import CFGGuide, Generate, RegexGuide, Write
-from outlines.fsm.json_schema import build_regex_from_schema
-from pydantic import BaseModel
+import xgrammar as xgr
 from transformers import PreTrainedTokenizerBase
 
+logger = logging.getLogger('guided_process')
 
-class BaseLogitsProcessor:
-
-    def init_state(self):
-        """Initialize the FSM states."""
-        self.fsm_state: DefaultDict[int, int] = defaultdict(int)
-
-    def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
-        """Use the FSM to bias the logits before sampling the next token."""
-
-        seq_id = hash(tuple(input_ids))
-
-        if len(input_ids) == 0:
-            self.init_state()
-        else:
-            last_token = input_ids[-1]
-            last_seq_id = hash(tuple(input_ids[:-1]))
-            self.fsm_state[seq_id] = self.fsm.get_next_state(state=self.fsm_state[last_seq_id], token_id=last_token)
-
-        instruction = self.fsm.get_next_instruction(self.fsm_state[seq_id])
 
-        if type(instruction) == Generate:
-            allowed_tokens = instruction.tokens
-        elif type(instruction) == Write:
-            # TODO: support fast forward tokens
-            allowed_tokens = [instruction.tokens[0]]
-        else:
-            raise TypeError(f'Unsupported instruction type {type(instruction)}')
+class BaseLogitsProcessor:
+    """Base logits processor that uses an xgrammar matcher for guided decoding."""
 
-        mask = torch.full((scores.shape[-1], ), -math.inf, device=scores.device)
-        mask[allowed_tokens] = 0
-        scores.add_(mask)
+    def __init__(self, compiled_grammar: xgr.CompiledGrammar, tokenizer_info: xgr.TokenizerInfo):
+        self.matcher = xgr.GrammarMatcher(compiled_grammar, terminate_without_stop_token=True)
+        self.token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
 
+    def process(self, scores: torch.Tensor) -> torch.Tensor:
+        """Apply grammar constraints to the logits before sampling the next token."""
+        self.matcher.fill_next_token_bitmask(self.token_bitmask)
+        xgr.apply_token_bitmask_inplace(scores, self.token_bitmask.to(scores.device))
         return scores
 
-    def adapt_tokenizer(self, tokenizer):
-        """Adapt tokenizer to use to compile the FSM.
+    def accept(self, token_id: int) -> bool:
+        """Update the matcher state after a token is generated."""
+        return self.matcher.accept_token(token_id)
 
-        The API of Outlines tokenizers is slightly different to that of `transformers`. In addition we need to handle
-        the missing spaces to Llama's tokenizer to be able to compile FSMs for this model.
-        """
-        from outlines.integrations.utils import adapt_tokenizer
-        tokenizer = adapt_tokenizer(tokenizer)
-        # vocab size greater than logits shape because of '[UNUSED_TOKEN_...]'
-        if hasattr(tokenizer, '_tokenizer'):
-            tokenizer.vocabulary = tokenizer._tokenizer.get_vocab(with_added_tokens=False)
-        return tokenizer
+    def reset(self):
+        """Reset the matcher state for the next generation."""
+        self.matcher.reset()
 
 
 class RegexLogitsProcessor(BaseLogitsProcessor):
+    """Regex-guided logits processor using xgrammar."""
 
-    def __init__(self, regex_string: str, tokenizer):
-        """Compile the FSM that drives the regex-structured generation.
-
-        Args:
-            regex_string: A string that represents a regular expression
-            tokenizer: The model's tokenizer
-        """
-        tokenizer = self.adapt_tokenizer(copy.deepcopy(tokenizer))
-        fsm = RegexGuide(regex_string, tokenizer)
-        self.fsm = fsm
-
-
-class JSONLogitsProcessor(RegexLogitsProcessor):
-
-    def __init__(self, schema: Union[str, Dict, BaseModel], tokenizer):
-        """Compile the FSM that drives the JSON-guided generation.
-
-        Args:
-            schema: A str schema that encodes the structure we want the model
-                to generate
-            tokenizer: The model's tokenizer
-        """
-        regex_string = build_regex_from_schema(schema)
-        super().__init__(regex_string, tokenizer)
-
+    def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase, vocab_size_padded: Optional[int] = None):
+        tokenizer = copy.deepcopy(tokenizer)
+        if vocab_size_padded is None:
+            vocab_size_padded = tokenizer.vocab_size
 
-class CFGLogitsProcessor(BaseLogitsProcessor):
+        tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size_padded)
 
-    def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
-        """Compile the FSM that drives the context free grammar generation.
+        compiler = xgr.GrammarCompiler(tokenizer_info)
+        compiled = compiler.compile_regex_grammar(regex_string)
 
-        Parameters
-        ----------
-        cfg
-            A string that represents a context-free grammar
-        tokenizer
-            The model's tokenizer
-        """
-        tokenizer = self.adapt_tokenizer(tokenizer)
-        fsm = CFGGuide(cfg, tokenizer)
-        self.fsm = fsm
+        super().__init__(compiled, tokenizer_info)
 
 
-# copied from https://github.com/vllm-project/vllm/blob/a7f65c2be93f491771aca31106f790bf381c0bad/vllm/model_executor/guided_decoding/outlines_decoding.py#L31  # noqa
-JSON_GRAMMAR = r"""
-?start: object | array
+class JSONLogitsProcessor(BaseLogitsProcessor):
+    """JSON-schema guided logits processor using xgrammar."""
 
-?value: object
-      | array
-      | UNESCAPED_STRING
-      | SIGNED_NUMBER      -> number
-      | "true"             -> true
-      | "false"            -> false
-      | "null"             -> null
+    def __init__(self, schema: str, tokenizer: PreTrainedTokenizerBase, vocab_size_padded: Optional[int] = None):
+        tokenizer = copy.deepcopy(tokenizer)
+        if vocab_size_padded is None:
+            vocab_size_padded = tokenizer.vocab_size
+        tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size_padded)
 
-array  : "[" [value ("," value)*] "]"
-object : "{" [pair ("," pair)*] "}"
-pair   : UNESCAPED_STRING ":" value
+        compiler = xgr.GrammarCompiler(tokenizer_info)
+        if isinstance(schema, str):
+            schema = json.loads(schema)
 
-%import common.UNESCAPED_STRING
-%import common.SIGNED_NUMBER
-%import common.WS
+        assert isinstance(schema, dict)
+        compiled = compiler.compile_json_schema(schema)
 
-%ignore WS
-"""
+        super().__init__(compiled, tokenizer_info)
 
 
 @lru_cache(maxsize=32)
-def _get_guided_logits_processor(guide: str, tokenizer: PreTrainedTokenizerBase, type: str):
+def _get_guided_logits_processor(guide: str,
+                                 tokenizer: PreTrainedTokenizerBase,
+                                 type: str,
+                                 vocab_size_padded: Optional[int] = None):
     try:
-        if type == 'json_object':
-            return CFGLogitsProcessor(guide, tokenizer)
-        elif type == 'json_schema':
-            return JSONLogitsProcessor(guide, tokenizer)
+        if type == 'json_schema':
+            return JSONLogitsProcessor(guide, tokenizer, vocab_size_padded)
         elif type == 'regex_schema':
-            return RegexLogitsProcessor(guide, tokenizer)
+            return RegexLogitsProcessor(guide, tokenizer, vocab_size_padded)
         else:
             return None
     except Exception as e:
-        from lmdeploy.utils import get_logger
-        logger = get_logger('lmdeploy')
         logger.error(e)
-        return None
+        raise
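
For reference, a minimal usage sketch of the new xgrammar-backed processor (not part of the commit). The module path is inferred from the `from .guided_process import ...` statements in lmdeploy/pytorch/engine/logits_process.py, and the tokenizer name is a placeholder for any HuggingFace tokenizer:

    import torch
    from transformers import AutoTokenizer

    from lmdeploy.pytorch.engine.guided_process import JSONLogitsProcessor

    tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct')  # placeholder model
    schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'
    # Pass the padded vocab size explicitly, as the engine does with sampling_vocab_size.
    processor = JSONLogitsProcessor(schema, tokenizer, vocab_size_padded=len(tokenizer))

    # One decode step: mask disallowed tokens, pick one, advance the matcher.
    logits = torch.randn(1, len(tokenizer))  # stand-in for model output
    logits = processor.process(logits)
    token_id = int(logits[0].argmax())  # greedy pick from the allowed set
    processor.accept(token_id)  # the matcher consumes the token
    processor.reset()  # ready for the next request

Unlike the removed outlines path, there is no per-sequence FSM state dictionary: the matcher itself carries the decoding state, which is why the engine changes below drop guided_input_ids entirely and call accept() right after sampling instead.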

lmdeploy/pytorch/engine/logits_process.py

Lines changed: 24 additions & 22 deletions
@@ -78,35 +78,30 @@ def _multinomial_sampling(scores: torch.Tensor,
     return multinomial_sampling(scores, seeds, offsets, indices)
 
 
-def _guided_sampling(response_formats: Tuple[Dict], scores: torch.Tensor, guided_input_ids: Optional[torch.Tensor],
-                     tokenizer: object):
-    if guided_input_ids is None:
-        return scores
-    for i in range(len(response_formats)):
-        _format = response_formats[i]
+def _get_guided_processors(response_formats: Tuple[Dict], tokenizer: object, vocab_size_padded: int):
+    processors = {}
+    for i, _format in enumerate(response_formats):
         if isinstance(_format, Dict) and _format.get('type', 'text') != 'text':
             if _format['type'] == 'json_schema':
                 schema = _format['json_schema']
                 if isinstance(schema, Dict):
                     for key in ['json_schema', 'schema']:
                         if key in schema:
                             schema = json.dumps(schema[key], ensure_ascii=False)
-                elif schema is None:
-                    from .guided_process import JSON_GRAMMAR
-                    schema = JSON_GRAMMAR
-                elif isinstance(schema, str):
+
+                if not isinstance(schema, str):
                     raise ValueError(f'Cannot parse schema {schema}. The schema must be '
                                      'either a dictionary or a string that contains the'
                                      ' JSON Schema specification')
             elif _format['type'] == 'regex_schema':
                 schema = _format.get('regex_schema', '')
             else:
                 raise ValueError(f"unsupported format type: {_format['type']}")
+
             from .guided_process import _get_guided_logits_processor
-            processor = _get_guided_logits_processor(schema, tokenizer, _format['type'])
-            if processor:
-                scores[i] = processor(guided_input_ids[i].tolist(), scores[i])
-    return scores
+            processors[i] = _get_guided_logits_processor(schema, tokenizer, _format['type'], vocab_size_padded)
+
+    return processors

@@ -131,7 +126,6 @@ class SamplingInputs:
     logits_processors: List[List[LogitsProcessor]] = None
     max_num_logprobs: Optional[int] = None
     all_ids: Optional[torch.Tensor] = None
-    guided_input_ids: Optional[torch.Tensor] = None
     num_ignore_eos: torch.Tensor = None
     batch_size: int = 0

@@ -169,6 +163,8 @@ def __init__(self,
         self.tokenizer = tokenizer
         self.sampling_vocab_size = sampling_vocab_size
         self.logprobs_mode = logprobs_mode
+        self.guided_processors = _get_guided_processors(sampling_inputs.response_formats, tokenizer,
+                                                        sampling_vocab_size)
 
     async def _wait_stream_once(self):
         """Wait stream once."""

@@ -205,9 +201,12 @@ async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
         sampling_inputs = self.sampling_inputs
         all_ids = sampling_inputs.all_ids
-        guided_input_ids = sampling_inputs.guided_input_ids
-
         custom_logits_processors = self.sampling_inputs.logits_processors
+        if self.guided_processors:
+            await self._wait_stream_once()
+            for i, processor in self.guided_processors.items():
+                scores[i] = processor.process(scores[i])
+
         if any(custom_logits_processors):
             await self._wait_stream_once()
             scores = _apply_custom_logits_processors(custom_logits_processors, all_ids, scores)

@@ -232,9 +231,6 @@ async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
             stop_mask = torch.where(ignore_eos[:, None], stop_mask, False)
             scores = _process_bad_words_(scores, stop_words, stop_mask)
 
-        if guided_input_ids is not None:
-            await self._wait_stream_once()
-            scores = _guided_sampling(sampling_inputs.response_formats, scores, guided_input_ids, self.tokenizer)
         return scores, logprobs

@@ -272,15 +268,21 @@ def __random_sampling(scores: torch.Tensor, indices: torch.LongTensor):
         logits = logits[..., :self.sampling_vocab_size]
 
         if sampling_inputs.max_top_k == 1:
-            return logits.argmax(-1)
+            result = logits.argmax(-1)
         else:
             # sort logits is too slow. and we only need topk logits
             max_topk = sampling_inputs.max_top_k
             if max_topk <= 0:
                 scores, indices = logits.sort(1, descending=True)
             else:
                 scores, indices = logits.topk(max_topk, dim=1)
-            return __random_sampling(scores, indices)
+            result = __random_sampling(scores, indices)
+
+        if self.guided_processors:
+            for i, processor in self.guided_processors.items():
+                processor.accept(result[i])
+
+        return result
 
     @torch.inference_mode()
     def compute_logprobs(self, raw_logprobs: torch.Tensor, token_ids: torch.LongTensor):
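
Net effect on the sampling path: the processors are built once in __init__, process() masks the logits before sampling, and accept() advances each matcher after a token is chosen. A condensed sketch of that per-step flow (a plain guided_processors dict and greedy sampling stand in for the engine's real state and sampler):

    import torch

    def guided_step(scores: torch.Tensor, guided_processors: dict) -> torch.Tensor:
        # Mask the logits of each guided sequence before sampling.
        for i, processor in guided_processors.items():
            scores[i] = processor.process(scores[i])
        # Sample; the engine uses top-k/multinomial, greedy keeps the sketch short.
        result = scores.argmax(-1)
        # Feed the sampled tokens back so each matcher tracks generation state.
        for i, processor in guided_processors.items():
            processor.accept(int(result[i]))
        return result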

lmdeploy/pytorch/strategies/ar/model_agent.py

Lines changed: 0 additions & 4 deletions
@@ -72,10 +72,6 @@ def _step_sampling_inputs(self, sampling_inputs: SamplingInputs, next_token_ids:
         if all_ids is not None:
             sampling_inputs.all_ids = torch.cat([all_ids, next_token_ids[:, None]], 1)
 
-        guided_input_ids = sampling_inputs.guided_input_ids
-        if guided_input_ids is not None:
-            sampling_inputs.guided_input_ids = torch.cat([guided_input_ids, next_token_ids[:, None]], 1)
-
         return sampling_inputs
 
     def make_stopping_criteria(self, seqs: SeqList) -> ARStoppingCriteria:

lmdeploy/pytorch/strategies/ar/sampling.py

Lines changed: 0 additions & 17 deletions
@@ -27,22 +27,6 @@ def _gather_all_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs)
     return output
 
 
-def _gather_guided_input_ids(pad_id: int, seqs: SeqList, sampling_inputs: 'SamplingInputs'):
-    """Gather input ids for guided decode."""
-    if not any(sampling_inputs.response_formats or ()):
-        return None
-    batch = len(seqs)
-    max_len = max(seq.num_new_tokens for seq in seqs)
-    output = torch.full((batch, max_len), pad_id, dtype=torch.int64)
-    for idx, seq in enumerate(seqs):
-        h_len = seq.num_new_tokens
-        if h_len == 0:
-            continue
-        h_ids = torch.from_numpy(seq.generated_ids)
-        output[idx, -h_len:] = h_ids
-    return output
-
-
 def _get_num_ignore_eos(seqs: SeqList):
     """Get num ignore eos."""
     ret = [seq.sampling_param.min_new_tokens - seq.num_new_tokens for seq in seqs]

@@ -186,6 +170,5 @@ def __get_bad_words(bad_words):
 
         pad_token_id = self.pad_token_id
         sampling_input.all_ids = _gather_all_ids(pad_token_id, seqs, sampling_input)
-        sampling_input.guided_input_ids = _gather_guided_input_ids(pad_token_id, seqs, sampling_input)
         sampling_input.num_ignore_eos = _get_num_ignore_eos(seqs)
         return sampling_input

lmdeploy/pytorch/strategies/dllm/sampling.py

Lines changed: 0 additions & 1 deletion
@@ -35,7 +35,6 @@ def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs:
             'random_seeds',
             'random_offsets',
             'all_ids',
-            'guided_input_ids',
             'num_ignore_eos',
         ]
         for name in update_attr_names:
