Skip to content

Commit 029d9c3

Browse files
committed
Add MagicMTP(block verify) and Triton optimization
Signed-off-by: chenaoxuan <[email protected]>
1 parent 84d7f5a commit 029d9c3

File tree

2 files changed

+292
-42
lines changed

2 files changed

+292
-42
lines changed

mypy.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ ignore_missing_imports = True
1515
[mypy-lm_eval.*]
1616
ignore_missing_imports = True
1717

18+
[mypy-triton.*]
19+
ignore_missing_imports = True
20+
1821
[mypy-msprobe.*]
1922
ignore_missing_imports = True
2023
allow_untyped_imports = True

vllm_ascend/sample/rejection_sampler.py

Lines changed: 289 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,149 @@
44
import torch
55
import torch.nn as nn
66
import vllm.v1.sample.rejection_sampler as rs
7+
from vllm.triton_utils import HAS_TRITON, tl, triton
78
from vllm.v1.sample.metadata import SamplingMetadata
89
from vllm.v1.sample.rejection_sampler import (RejectionSampler,
910
apply_sampling_constraints,
1011
generate_uniform_probs)
1112
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
1213

14+
if HAS_TRITON:
15+
16+
@triton.jit(do_not_specialize=["max_spec_len"])
17+
def rejection_greedy_sample_kernel(
18+
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
19+
cu_num_draft_tokens_ptr, # [batch_size]
20+
draft_token_ids_ptr, # [num_tokens]
21+
target_argmax_ptr, # [num_tokens]
22+
bonus_token_ids_ptr, # [batch_size]
23+
is_greedy_ptr, # [batch_size] or None
24+
max_spec_len,
25+
):
26+
req_idx = tl.program_id(0)
27+
# FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
28+
# re-compilation may happen during runtime when is_greedy_ptr is None.
29+
is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr +
30+
req_idx)
31+
if not is_greedy:
32+
# Early exit for non-greedy sampling requests.
33+
return
34+
35+
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
36+
req_idx - 1)
37+
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
38+
num_draft_tokens = end_idx - start_idx
39+
40+
rejected = False
41+
for pos in range(num_draft_tokens):
42+
if not rejected:
43+
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
44+
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
45+
tl.store(
46+
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
47+
target_argmax_id,
48+
)
49+
if draft_token_id != target_argmax_id:
50+
# Reject.
51+
rejected = True
52+
53+
if not rejected:
54+
# If all tokens are accepted, append the bonus token.
55+
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
56+
tl.store(
57+
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
58+
num_draft_tokens,
59+
bonus_token_id,
60+
)
61+
62+
@triton.jit(do_not_specialize=["max_spec_len"])
63+
def rejection_random_sample_block_verify_kernel(
64+
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
65+
cu_num_draft_tokens_ptr, # [batch_size]
66+
draft_token_ids_ptr, # [num_tokens]
67+
draft_probs_ptr, # [num_tokens, vocab_size] or None
68+
target_probs_ptr, # [num_tokens, vocab_size]
69+
bonus_token_ids_ptr, # [batch_size]
70+
uniform_probs_ptr, # [num_tokens]
71+
is_greedy_ptr, # [batch_size]
72+
max_spec_len,
73+
vocab_size,
74+
NO_DRAFT_PROBS: tl.constexpr,
75+
SUB_BLOCK: tl.constexpr = 1500,
76+
):
77+
req_idx = tl.program_id(0)
78+
is_greedy = tl.load(is_greedy_ptr + req_idx)
79+
if is_greedy:
80+
# Early exit for greedy sampling requests.
81+
return
82+
83+
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
84+
req_idx - 1)
85+
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
86+
num_draft_tokens = end_idx - start_idx
87+
88+
rejected = False
89+
pi = 1.0
90+
uniform_prob = 1.0
91+
last_accepted_token_pos = -1
92+
93+
for pos in range(num_draft_tokens):
94+
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
95+
target_prob = tl.load(target_probs_ptr +
96+
(start_idx + pos) * vocab_size +
97+
draft_token_id)
98+
tmp_uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
99+
uniform_prob = uniform_prob * tmp_uniform_prob
100+
101+
if NO_DRAFT_PROBS:
102+
draft_prob = 1
103+
else:
104+
draft_prob = tl.load(draft_probs_ptr +
105+
(start_idx + pos) * vocab_size +
106+
draft_token_id)
107+
108+
pi = min(pi * target_prob / draft_prob, 1.0)
109+
if draft_prob > 0 and pi >= uniform_prob:
110+
last_accepted_token_pos = pos
111+
rejected = False
112+
else:
113+
rejected = True
114+
115+
if last_accepted_token_pos > -1:
116+
for pos in range(last_accepted_token_pos + 1):
117+
token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
118+
tl.store(
119+
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
120+
token_id)
121+
122+
if rejected:
123+
loop = (vocab_size + SUB_BLOCK - 1) // SUB_BLOCK
124+
global_recovered_id = -1
125+
global_max_p = -1.0
126+
for loop_i in range(loop):
127+
vocab_start = loop_i * SUB_BLOCK
128+
vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
129+
tmp_target_prob = tl.load(
130+
target_probs_ptr +
131+
(start_idx + last_accepted_token_pos + 1) * vocab_size +
132+
vocab_offset,
133+
mask=vocab_offset < vocab_size,
134+
other=0)
135+
recovered_id = tl.argmax(tmp_target_prob, axis=-1)
136+
max_p = tl.get_element(tmp_target_prob, (recovered_id, ))
137+
if max_p > global_max_p:
138+
global_max_p = max_p
139+
global_recovered_id = vocab_start + recovered_id
140+
tl.store(
141+
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
142+
last_accepted_token_pos + 1, global_recovered_id)
143+
else:
144+
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
145+
tl.store(
146+
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
147+
num_draft_tokens, bonus_token_id)
148+
149+
13150
PLACEHOLDER_TOKEN_ID = -1
14151
GREEDY_TEMPERATURE = -1
15152
# Maximum number of speculative draft tokens allowed per request in a single
@@ -134,6 +271,9 @@ def rejection_sample(
134271
assert bonus_token_ids.is_contiguous()
135272
assert target_probs.shape == (num_tokens, vocab_size)
136273

274+
# When num_speculative_tokens>=3, using block verify.
275+
using_block_verify = max_spec_len >= 3
276+
137277
# Create output buffer.
138278
output_token_ids = torch.empty(
139279
(batch_size, max_spec_len + 1),
@@ -149,25 +289,36 @@ def rejection_sample(
149289
if not sampling_metadata.all_random:
150290
# Rejection sampling for greedy sampling requests.
151291
target_argmax = target_probs.argmax(dim=-1)
152-
if min(num_draft_tokens) == 1 and max(
153-
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
154-
rejection_greedy_sample_spec_len_1_pytorch(
155-
output_token_ids,
156-
draft_token_ids,
157-
target_argmax,
158-
bonus_token_ids,
159-
)
160-
else:
161-
rejection_greedy_sample_pytorch(
292+
if HAS_TRITON:
293+
rejection_greedy_sample_kernel[(batch_size, )](
162294
output_token_ids,
163295
cu_num_draft_tokens,
164296
draft_token_ids,
165297
target_argmax,
166298
bonus_token_ids,
167-
num_draft_tokens,
168-
max_spec_len,
169299
is_greedy,
300+
max_spec_len,
170301
)
302+
else:
303+
if min(num_draft_tokens) == 1 and max(
304+
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
305+
rejection_greedy_sample_spec_len_1_pytorch(
306+
output_token_ids,
307+
draft_token_ids,
308+
target_argmax,
309+
bonus_token_ids,
310+
)
311+
else:
312+
rejection_greedy_sample_pytorch(
313+
output_token_ids,
314+
cu_num_draft_tokens,
315+
draft_token_ids,
316+
target_argmax,
317+
bonus_token_ids,
318+
num_draft_tokens,
319+
max_spec_len,
320+
is_greedy,
321+
)
171322
if sampling_metadata.all_greedy:
172323
return output_token_ids
173324

@@ -178,37 +329,68 @@ def rejection_sample(
178329
num_draft_tokens,
179330
sampling_metadata.generators,
180331
device,
181-
)
182-
183-
# Sample recovered tokens for each position.
184-
# [num_tokens]
185-
recovered_token_ids = sample_recovered_tokens(
186-
max_spec_len,
187-
num_draft_tokens,
188-
cu_num_draft_tokens,
189-
draft_token_ids,
190-
draft_probs,
191-
target_probs,
192-
sampling_metadata,
193-
device,
194-
)
332+
).to(torch.float32)
333+
334+
if not using_block_verify:
335+
# Sample recovered tokens for each position.
336+
# [num_tokens]
337+
recovered_token_ids = sample_recovered_tokens(
338+
max_spec_len,
339+
num_draft_tokens,
340+
cu_num_draft_tokens,
341+
draft_token_ids,
342+
draft_probs,
343+
target_probs,
344+
sampling_metadata,
345+
device,
346+
)
195347

196-
# Rejection sampling for random sampling requests.
197-
rejection_random_sample_pytorch(
198-
output_token_ids,
199-
cu_num_draft_tokens,
200-
draft_token_ids,
201-
draft_probs,
202-
target_probs,
203-
bonus_token_ids,
204-
recovered_token_ids,
205-
uniform_probs,
206-
is_greedy,
207-
max_spec_len,
208-
vocab_size,
209-
IS_NGRAM=draft_probs is None,
210-
# num_warps=1,
211-
)
348+
# Rejection sampling for random sampling requests.
349+
rejection_random_sample_pytorch(
350+
output_token_ids,
351+
cu_num_draft_tokens,
352+
draft_token_ids,
353+
draft_probs,
354+
target_probs,
355+
bonus_token_ids,
356+
recovered_token_ids,
357+
uniform_probs,
358+
is_greedy,
359+
max_spec_len,
360+
vocab_size,
361+
IS_NGRAM=draft_probs is None,
362+
# num_warps=1,
363+
)
364+
else:
365+
# MagicMTP: Improving acceptance rate with Block Verify.
366+
if HAS_TRITON:
367+
rejection_random_sample_block_verify_kernel[(batch_size, )](
368+
output_token_ids,
369+
cu_num_draft_tokens,
370+
draft_token_ids,
371+
draft_probs,
372+
target_probs,
373+
bonus_token_ids,
374+
uniform_probs,
375+
is_greedy,
376+
max_spec_len,
377+
vocab_size,
378+
NO_DRAFT_PROBS=draft_probs is None,
379+
multibuffer=True,
380+
)
381+
else:
382+
rejection_random_sample_block_verify_pytorch(output_token_ids,
383+
cu_num_draft_tokens,
384+
draft_token_ids,
385+
draft_probs,
386+
target_probs,
387+
bonus_token_ids,
388+
uniform_probs,
389+
is_greedy,
390+
max_spec_len,
391+
vocab_size,
392+
IS_NGRAM=draft_probs
393+
is None)
212394
return output_token_ids
213395

214396

@@ -504,4 +686,69 @@ def sample_recovered_tokens_pytorch(
504686
target_probs[token_idx, draft_token_id] = orig_prob
505687

506688

689+
def rejection_random_sample_block_verify_pytorch(
690+
output_token_ids, # [batch_size, max_spec_len + 1]
691+
cu_num_draft_tokens, # [batch_size]
692+
draft_token_ids, # [num_tokens]
693+
draft_probs, # [num_tokens, vocab_size] or None
694+
target_probs, # [num_tokens, vocab_size]
695+
bonus_token_ids, # [batch_size]
696+
uniform_probs, # [num_tokens]
697+
is_greedy, # [batch_size]
698+
max_spec_len,
699+
vocab_size,
700+
IS_NGRAM=False,
701+
):
702+
batch_size = output_token_ids.shape[0]
703+
704+
for req_idx in range(batch_size):
705+
if is_greedy[req_idx]:
706+
continue
707+
708+
if req_idx == 0:
709+
start_idx = 0
710+
else:
711+
start_idx = cu_num_draft_tokens[req_idx - 1].item()
712+
end_idx = cu_num_draft_tokens[req_idx].item()
713+
num_draft_tokens = end_idx - start_idx
714+
715+
rejected = False
716+
pi = 1.0
717+
uniform_prob = 1.0
718+
last_accepted_token_pos = -1
719+
for pos in range(num_draft_tokens):
720+
draft_token_id = draft_token_ids[start_idx + pos].item()
721+
722+
target_prob = target_probs[start_idx + pos, draft_token_id].item()
723+
uniform_prob = uniform_prob * uniform_probs[start_idx + pos].item()
724+
725+
if IS_NGRAM:
726+
draft_prob = 1.0
727+
else:
728+
draft_prob = draft_probs[start_idx + pos,
729+
draft_token_id].item()
730+
731+
pi = min(pi * target_prob / draft_prob, 1.0)
732+
733+
if draft_prob > 0 and pi >= uniform_prob:
734+
last_accepted_token_pos = pos
735+
rejected = False
736+
else:
737+
rejected = True
738+
739+
if last_accepted_token_pos > -1:
740+
for pos in range(last_accepted_token_pos + 1):
741+
draft_token_id = draft_token_ids[start_idx + pos].item()
742+
output_token_ids[req_idx, pos] = draft_token_id
743+
744+
if rejected:
745+
recovered_token_id = torch.argmax(
746+
target_probs[start_idx + last_accepted_token_pos + 1]).item()
747+
output_token_ids[req_idx,
748+
last_accepted_token_pos + 1] = recovered_token_id
749+
else:
750+
bonus_token_id = bonus_token_ids[req_idx].item()
751+
output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
752+
753+
507754
rs.expand_batch_to_tokens = expand_batch_to_tokens

0 commit comments

Comments
 (0)