Skip to content

Commit befa9e5

Browse files
committed
Add MagicMTP(block verify) and Triton optimization
Signed-off-by: chenaoxuan <[email protected]>
1 parent 941d54a commit befa9e5

File tree

2 files changed

+299
-42
lines changed

2 files changed

+299
-42
lines changed

mypy.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ ignore_missing_imports = True
1515
[mypy-lm_eval.*]
1616
ignore_missing_imports = True
1717

18+
[mypy-triton.*]
19+
ignore_missing_imports = True
20+
1821
[mypy-msprobe.*]
1922
ignore_missing_imports = True
2023
allow_untyped_imports = True

vllm_ascend/sample/rejection_sampler.py

Lines changed: 296 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,150 @@
1010
generate_uniform_probs)
1111
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
1212

13+
try:
    import triton
    import triton.language as tl

    # Greedy-path rejection sampling: one program instance per request.
    # A draft token is accepted while it matches the target argmax; the
    # first mismatch writes the target argmax token and rejects the rest.
    @triton.jit(do_not_specialize=["max_spec_len"])
    def rejection_greedy_sample_kernel(
        output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
        cu_num_draft_tokens_ptr,  # [batch_size]
        draft_token_ids_ptr,  # [num_tokens]
        target_argmax_ptr,  # [num_tokens]
        bonus_token_ids_ptr,  # [batch_size]
        is_greedy_ptr,  # [batch_size] or None
        max_spec_len,
    ):
        req_idx = tl.program_id(0)
        # FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
        # re-compilation may happen during runtime when is_greedy_ptr is None.
        is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr +
                                                               req_idx)
        if not is_greedy:
            # Early exit for non-greedy sampling requests.
            return

        # cu_num_draft_tokens is a cumulative sum, so this request's draft
        # tokens occupy the flat range [start_idx, end_idx).
        start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
                                                   req_idx - 1)
        end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
        num_draft_tokens = end_idx - start_idx

        rejected = False
        for pos in range(num_draft_tokens):
            if not rejected:
                draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
                target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
                # The target argmax is written even at the mismatch position:
                # it serves as the corrected token for that slot.
                tl.store(
                    output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                    target_argmax_id,
                )
                if draft_token_id != target_argmax_id:
                    # Reject.
                    rejected = True

        if not rejected:
            # If all tokens are accepted, append the bonus token.
            bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
            tl.store(
                output_token_ids_ptr + req_idx * (max_spec_len + 1) +
                num_draft_tokens,
                bonus_token_id,
            )

    # MagicMTP block-verify rejection sampling for random-sampling requests.
    # Unlike token-by-token rejection, a position rejected earlier can be
    # superseded by a later acceptance: `pi` accumulates the clamped product
    # of target/draft probability ratios, `uniform_prob` accumulates the
    # product of the per-position uniform samples, and any position where
    # pi >= uniform_prob re-accepts the whole prefix up to that position.
    @triton.jit(do_not_specialize=["max_spec_len"])
    def rejection_random_sample_block_verify_kernel(
        output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
        cu_num_draft_tokens_ptr,  # [batch_size]
        draft_token_ids_ptr,  # [num_tokens]
        draft_probs_ptr,  # [num_tokens, vocab_size] or None
        target_probs_ptr,  # [num_tokens, vocab_size]
        bonus_token_ids_ptr,  # [batch_size]
        uniform_probs_ptr,  # [num_tokens]
        is_greedy_ptr,  # [batch_size]
        max_spec_len,
        vocab_size,
        NO_DRAFT_PROBS: tl.constexpr,
        # Vocab is scanned in SUB_BLOCK-sized tiles when recovering a token.
        SUB_BLOCK: tl.constexpr = 1500,
    ):
        req_idx = tl.program_id(0)
        is_greedy = tl.load(is_greedy_ptr + req_idx)
        if is_greedy:
            # Early exit for greedy sampling requests.
            return

        # Flat token range [start_idx, end_idx) for this request.
        start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
                                                   req_idx - 1)
        end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
        num_draft_tokens = end_idx - start_idx

        rejected = False
        pi = 1.0
        uniform_prob = 1.0
        last_accepted_token_pos = -1

        for pos in range(num_draft_tokens):
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            target_prob = tl.load(target_probs_ptr +
                                  (start_idx + pos) * vocab_size +
                                  draft_token_id)
            tmp_uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
            uniform_prob = uniform_prob * tmp_uniform_prob

            if NO_DRAFT_PROBS:
                # N-gram / ngram-style drafts carry no probabilities;
                # treat the draft distribution as a point mass.
                draft_prob = 1
            else:
                draft_prob = tl.load(draft_probs_ptr +
                                     (start_idx + pos) * vocab_size +
                                     draft_token_id)

            # NOTE(review): draft_prob may be 0 here; presumably the division
            # follows IEEE semantics (inf/nan) on this backend and the
            # `draft_prob > 0` guard below rejects the position — confirm.
            pi = min(pi * target_prob / draft_prob, 1.0)
            if draft_prob > 0 and pi >= uniform_prob:
                # Block verify: accepting this position accepts the whole
                # prefix, overriding any earlier rejection.
                last_accepted_token_pos = pos
                rejected = False
            else:
                rejected = True

        if last_accepted_token_pos > -1:
            # Commit the accepted prefix of draft tokens.
            for pos in range(last_accepted_token_pos + 1):
                token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
                tl.store(
                    output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                    token_id)

        if rejected:
            # Recover a token at the first unaccepted position by taking the
            # argmax of the target distribution, tiled over the vocab in
            # SUB_BLOCK chunks to bound on-chip memory use.
            loop = (vocab_size + SUB_BLOCK - 1) // SUB_BLOCK
            global_recovered_id = -1
            global_max_p = -1.0
            for loop_i in range(loop):
                vocab_start = loop_i * SUB_BLOCK
                vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
                tmp_target_prob = tl.load(
                    target_probs_ptr +
                    (start_idx + last_accepted_token_pos + 1) * vocab_size +
                    vocab_offset,
                    mask=vocab_offset < vocab_size,
                    other=0)
                recovered_id = tl.argmax(tmp_target_prob, axis=-1)
                max_p = tl.get_element(tmp_target_prob, (recovered_id, ))
                if max_p > global_max_p:
                    global_max_p = max_p
                    global_recovered_id = vocab_start + recovered_id
            tl.store(
                output_token_ids_ptr + req_idx * (max_spec_len + 1) +
                last_accepted_token_pos + 1, global_recovered_id)
        else:
            # Entire block accepted: append the bonus token.
            bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
            tl.store(
                output_token_ids_ptr + req_idx * (max_spec_len + 1) +
                num_draft_tokens, bonus_token_id)

    TRITON_ASCEND_AVAILABLE = True
except ImportError:
    # Triton is optional on this platform; callers fall back to the
    # pure-PyTorch implementations when this flag is False.
    TRITON_ASCEND_AVAILABLE = False
156+
13157
PLACEHOLDER_TOKEN_ID = -1
14158
GREEDY_TEMPERATURE = -1
15159
# Maximum number of speculative draft tokens allowed per request in a single
@@ -134,6 +278,9 @@ def rejection_sample(
134278
assert bonus_token_ids.is_contiguous()
135279
assert target_probs.shape == (num_tokens, vocab_size)
136280

281+
# Block-verify switch: when MTP >= 3, use block verify in the rejection sampler.
282+
using_block_verify = max_spec_len >= 3
283+
137284
# Create output buffer.
138285
output_token_ids = torch.empty(
139286
(batch_size, max_spec_len + 1),
@@ -149,25 +296,36 @@ def rejection_sample(
149296
if not sampling_metadata.all_random:
150297
# Rejection sampling for greedy sampling requests.
151298
target_argmax = target_probs.argmax(dim=-1)
152-
if min(num_draft_tokens) == 1 and max(
153-
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
154-
rejection_greedy_sample_spec_len_1_pytorch(
155-
output_token_ids,
156-
draft_token_ids,
157-
target_argmax,
158-
bonus_token_ids,
159-
)
160-
else:
161-
rejection_greedy_sample_pytorch(
299+
if TRITON_ASCEND_AVAILABLE:
300+
rejection_greedy_sample_kernel[(batch_size, )](
162301
output_token_ids,
163302
cu_num_draft_tokens,
164303
draft_token_ids,
165304
target_argmax,
166305
bonus_token_ids,
167-
num_draft_tokens,
168-
max_spec_len,
169306
is_greedy,
307+
max_spec_len,
170308
)
309+
else:
310+
if min(num_draft_tokens) == 1 and max(
311+
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
312+
rejection_greedy_sample_spec_len_1_pytorch(
313+
output_token_ids,
314+
draft_token_ids,
315+
target_argmax,
316+
bonus_token_ids,
317+
)
318+
else:
319+
rejection_greedy_sample_pytorch(
320+
output_token_ids,
321+
cu_num_draft_tokens,
322+
draft_token_ids,
323+
target_argmax,
324+
bonus_token_ids,
325+
num_draft_tokens,
326+
max_spec_len,
327+
is_greedy,
328+
)
171329
if sampling_metadata.all_greedy:
172330
return output_token_ids
173331

@@ -178,37 +336,68 @@ def rejection_sample(
178336
num_draft_tokens,
179337
sampling_metadata.generators,
180338
device,
181-
)
182-
183-
# Sample recovered tokens for each position.
184-
# [num_tokens]
185-
recovered_token_ids = sample_recovered_tokens(
186-
max_spec_len,
187-
num_draft_tokens,
188-
cu_num_draft_tokens,
189-
draft_token_ids,
190-
draft_probs,
191-
target_probs,
192-
sampling_metadata,
193-
device,
194-
)
339+
).to(torch.float32)
340+
341+
if not using_block_verify:
342+
# Sample recovered tokens for each position.
343+
# [num_tokens]
344+
recovered_token_ids = sample_recovered_tokens(
345+
max_spec_len,
346+
num_draft_tokens,
347+
cu_num_draft_tokens,
348+
draft_token_ids,
349+
draft_probs,
350+
target_probs,
351+
sampling_metadata,
352+
device,
353+
)
195354

196-
# Rejection sampling for random sampling requests.
197-
rejection_random_sample_pytorch(
198-
output_token_ids,
199-
cu_num_draft_tokens,
200-
draft_token_ids,
201-
draft_probs,
202-
target_probs,
203-
bonus_token_ids,
204-
recovered_token_ids,
205-
uniform_probs,
206-
is_greedy,
207-
max_spec_len,
208-
vocab_size,
209-
IS_NGRAM=draft_probs is None,
210-
# num_warps=1,
211-
)
355+
# Rejection sampling for random sampling requests.
356+
rejection_random_sample_pytorch(
357+
output_token_ids,
358+
cu_num_draft_tokens,
359+
draft_token_ids,
360+
draft_probs,
361+
target_probs,
362+
bonus_token_ids,
363+
recovered_token_ids,
364+
uniform_probs,
365+
is_greedy,
366+
max_spec_len,
367+
vocab_size,
368+
IS_NGRAM=draft_probs is None,
369+
# num_warps=1,
370+
)
371+
else:
372+
# MagicMTP: Improving acceptance rate with Block Verify.
373+
if TRITON_ASCEND_AVAILABLE:
374+
rejection_random_sample_block_verify_kernel[(batch_size, )](
375+
output_token_ids,
376+
cu_num_draft_tokens,
377+
draft_token_ids,
378+
draft_probs,
379+
target_probs,
380+
bonus_token_ids,
381+
uniform_probs,
382+
is_greedy,
383+
max_spec_len,
384+
vocab_size,
385+
NO_DRAFT_PROBS=draft_probs is None,
386+
multibuffer=True,
387+
)
388+
else:
389+
rejection_random_sample_block_verify_pytorch(output_token_ids,
390+
cu_num_draft_tokens,
391+
draft_token_ids,
392+
draft_probs,
393+
target_probs,
394+
bonus_token_ids,
395+
uniform_probs,
396+
is_greedy,
397+
max_spec_len,
398+
vocab_size,
399+
IS_NGRAM=draft_probs
400+
is None)
212401
return output_token_ids
213402

214403

@@ -504,4 +693,69 @@ def sample_recovered_tokens_pytorch(
504693
target_probs[token_idx, draft_token_id] = orig_prob
505694

506695

696+
def rejection_random_sample_block_verify_pytorch(
    output_token_ids,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens,  # [batch_size]
    draft_token_ids,  # [num_tokens]
    draft_probs,  # [num_tokens, vocab_size] or None
    target_probs,  # [num_tokens, vocab_size]
    bonus_token_ids,  # [batch_size]
    uniform_probs,  # [num_tokens]
    is_greedy,  # [batch_size]
    max_spec_len,
    vocab_size,
    IS_NGRAM=False,
):
    """Pure-PyTorch fallback for MagicMTP block-verify rejection sampling.

    For each non-greedy request, walk its draft tokens keeping a running
    acceptance ratio ``pi`` (product of target/draft probability ratios,
    clamped to 1.0) and a running product ``uniform_prob`` of the uniform
    samples. A position is accepted when ``pi >= uniform_prob``; unlike
    token-by-token rejection sampling, a later acceptance re-accepts the
    whole prefix even after an earlier rejection ("block verify").

    Accepted draft tokens are written into ``output_token_ids``. If the
    block ends rejected, the argmax of the target distribution at the first
    unaccepted position is written as the recovered token; otherwise the
    bonus token is appended after all draft tokens.

    Mutates ``output_token_ids`` in place and returns None.
    ``max_spec_len`` and ``vocab_size`` are unused here but kept for
    signature parity with the Triton kernel.
    """
    batch_size = output_token_ids.shape[0]

    for req_idx in range(batch_size):
        if is_greedy[req_idx]:
            # Greedy requests are handled by the greedy sampling path.
            continue

        # cu_num_draft_tokens is a cumulative sum; this request's draft
        # tokens occupy the flat range [start_idx, end_idx).
        if req_idx == 0:
            start_idx = 0
        else:
            start_idx = cu_num_draft_tokens[req_idx - 1].item()
        end_idx = cu_num_draft_tokens[req_idx].item()
        num_draft_tokens = end_idx - start_idx

        rejected = False
        pi = 1.0
        uniform_prob = 1.0
        last_accepted_token_pos = -1
        for pos in range(num_draft_tokens):
            draft_token_id = draft_token_ids[start_idx + pos].item()

            target_prob = target_probs[start_idx + pos, draft_token_id].item()
            uniform_prob = uniform_prob * uniform_probs[start_idx + pos].item()

            if IS_NGRAM:
                # N-gram drafts carry no probabilities; treat the draft
                # distribution as a point mass.
                draft_prob = 1.0
            else:
                draft_prob = draft_probs[start_idx + pos,
                                         draft_token_id].item()

            if draft_prob > 0:
                pi = min(pi * target_prob / draft_prob, 1.0)
                if pi >= uniform_prob:
                    # Block verify: accepting this position accepts the
                    # whole prefix, overriding any earlier rejection.
                    last_accepted_token_pos = pos
                    rejected = False
                else:
                    rejected = True
            else:
                # Bug fix: the unguarded `target_prob / draft_prob` raised
                # ZeroDivisionError for a zero draft probability. Such a
                # token can never be accepted; clamp pi to 1.0, mirroring
                # the Triton kernel's inf -> min(., 1.0) path when
                # target_prob > 0.
                # NOTE(review): when target_prob is also 0 the kernel
                # propagates NaN instead — confirm whether exact parity
                # matters for that degenerate case.
                pi = 1.0
                rejected = True

        if last_accepted_token_pos > -1:
            # Commit the accepted prefix of draft tokens.
            for pos in range(last_accepted_token_pos + 1):
                draft_token_id = draft_token_ids[start_idx + pos].item()
                output_token_ids[req_idx, pos] = draft_token_id

        if rejected:
            # Recover a token at the first unaccepted position from the
            # target distribution's argmax.
            recovered_token_id = torch.argmax(
                target_probs[start_idx + last_accepted_token_pos + 1]).item()
            output_token_ids[req_idx,
                             last_accepted_token_pos + 1] = recovered_token_id
        else:
            # Entire block accepted: append the bonus token.
            bonus_token_id = bonus_token_ids[req_idx].item()
            output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
759+
760+
507761
rs.expand_batch_to_tokens = expand_batch_to_tokens

0 commit comments

Comments
 (0)