Skip to content

Commit 74390bb

Browse files
committed
Add MagicMTP(block verify) and Triton optimization
Signed-off-by: chenaoxuan <[email protected]>
1 parent 941d54a commit 74390bb

File tree

2 files changed

+296
-42
lines changed

2 files changed

+296
-42
lines changed

mypy.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ ignore_missing_imports = True
1515
[mypy-lm_eval.*]
1616
ignore_missing_imports = True
1717

18+
[mypy-triton.*]
19+
ignore_missing_imports = True
20+
1821
[mypy-msprobe.*]
1922
ignore_missing_imports = True
2023
allow_untyped_imports = True

vllm_ascend/sample/rejection_sampler.py

Lines changed: 293 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,147 @@
1010
generate_uniform_probs)
1111
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
1212

13+
try:
    import triton
    import triton.language as tl

    # One program instance per request (grid = (batch_size,)).
    # Writes accepted draft tokens (greedy verification: draft must equal the
    # target argmax) into output_token_ids, plus the bonus token when every
    # draft token is accepted.
    @triton.jit(do_not_specialize=["max_spec_len"])
    def rejection_greedy_sample_kernel(
        output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
        cu_num_draft_tokens_ptr,  # [batch_size], cumulative draft counts
        draft_token_ids_ptr,  # [num_tokens]
        target_argmax_ptr,  # [num_tokens]
        bonus_token_ids_ptr,  # [batch_size]
        is_greedy_ptr,  # [batch_size] or None
        max_spec_len,
    ):
        req_idx = tl.program_id(0)
        # FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
        # re-compilation may happen during runtime when is_greedy_ptr is None.
        is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr +
                                                               req_idx)
        if not is_greedy:
            # Early exit for non-greedy sampling requests.
            return

        # This request's slice of the flattened token arrays:
        # [start_idx, end_idx) from the exclusive-prefix-sum layout.
        start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
                                                   req_idx - 1)
        end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
        num_draft_tokens = end_idx - start_idx

        rejected = False
        for pos in range(num_draft_tokens):
            if not rejected:
                draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
                target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
                # The target argmax is emitted even at the rejecting position:
                # it is the corrected token for that slot.
                tl.store(
                    output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                    target_argmax_id,
                )
                if draft_token_id != target_argmax_id:
                    # Reject.
                    rejected = True

        if not rejected:
            # If all tokens are accepted, append the bonus token.
            bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
            tl.store(
                output_token_ids_ptr + req_idx * (max_spec_len + 1) +
                num_draft_tokens,
                bonus_token_id,
            )

    # Block-verify ("MagicMTP") rejection sampling: instead of stopping at the
    # first rejection, the running acceptance ratio pi is compared against the
    # running product of uniform samples, so a later draft token can still be
    # accepted as a block. One program instance per request.
    @triton.jit(do_not_specialize=["max_spec_len"])
    def rejection_random_sample_block_verify_kernel(
        output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
        cu_num_draft_tokens_ptr,  # [batch_size]
        draft_token_ids_ptr,  # [num_tokens]
        draft_probs_ptr,  # [num_tokens, vocab_size] or None
        target_probs_ptr,  # [num_tokens, vocab_size]
        bonus_token_ids_ptr,  # [batch_size]
        uniform_probs_ptr,  # [num_tokens]
        is_greedy_ptr,  # [batch_size]
        max_spec_len,
        vocab_size,
        NO_DRAFT_PROBS: tl.constexpr,  # True for ngram drafts (prob := 1)
        SUB_BLOCK: tl.constexpr = 1500,  # vocab tile size for the argmax scan
    ):
        req_idx = tl.program_id(0)
        is_greedy = tl.load(is_greedy_ptr + req_idx)
        if is_greedy:
            # Early exit for greedy sampling requests.
            return

        start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
                                                   req_idx - 1)
        end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
        num_draft_tokens = end_idx - start_idx

        rejected = False
        pi = 1.0  # running product of per-position acceptance ratios, capped at 1
        uniform_prob = 1.0  # running product of uniform draws
        last_accepted_token_pos = -1

        for pos in range(num_draft_tokens):
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            target_prob = tl.load(target_probs_ptr +
                                  (start_idx + pos) * vocab_size +
                                  draft_token_id)
            tmp_uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
            uniform_prob = uniform_prob * tmp_uniform_prob

            if NO_DRAFT_PROBS:
                draft_prob = 1
            else:
                draft_prob = tl.load(draft_probs_ptr +
                                     (start_idx + pos) * vocab_size +
                                     draft_token_id)

            # NOTE(review): when draft_prob == 0 this divides by zero; float
            # division yields inf here and the draft_prob > 0 guard below
            # still rejects the position.
            pi = min(pi * target_prob / draft_prob, 1.0)
            if draft_prob > 0 and pi >= uniform_prob:
                # Accept the whole prefix up to this position (block verify).
                last_accepted_token_pos = pos
                rejected = False
            else:
                rejected = True

        if last_accepted_token_pos > -1:
            # Emit every draft token in the accepted prefix.
            for pos in range(last_accepted_token_pos + 1):
                token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
                tl.store(
                    output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                    token_id)

        if rejected:
            # Recover the position after the accepted prefix with the target
            # argmax, scanning the vocab in SUB_BLOCK-sized tiles.
            loop = (vocab_size + SUB_BLOCK - 1) // SUB_BLOCK
            global_recovered_id = -1
            global_max_p = -1.0
            for loop_i in range(loop):
                vocab_start = loop_i * SUB_BLOCK
                vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
                tmp_target_prob = tl.load(
                    target_probs_ptr +
                    (start_idx + last_accepted_token_pos + 1) * vocab_size +
                    vocab_offset,
                    mask=vocab_offset < vocab_size,
                    other=0)
                recovered_id = tl.argmax(tmp_target_prob, axis=-1)
                max_p = tl.get_element(tmp_target_prob, (recovered_id, ))
                if max_p > global_max_p:
                    global_max_p = max_p
                    global_recovered_id = vocab_start + recovered_id
            tl.store(
                output_token_ids_ptr + req_idx * (max_spec_len + 1) +
                last_accepted_token_pos + 1, global_recovered_id)
        else:
            # All draft tokens accepted: append the bonus token.
            bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
            tl.store(
                output_token_ids_ptr + req_idx * (max_spec_len + 1) +
                num_draft_tokens, bonus_token_id)

    TRITON_ASCEND_AVAILABLE = True
except ImportError:
    # Triton not installed: callers fall back to the pytorch implementations.
    TRITON_ASCEND_AVAILABLE = False
153+
13154
PLACEHOLDER_TOKEN_ID = -1
14155
GREEDY_TEMPERATURE = -1
15156
# Maximum number of speculative draft tokens allowed per request in a single
@@ -134,6 +275,9 @@ def rejection_sample(
134275
assert bonus_token_ids.is_contiguous()
135276
assert target_probs.shape == (num_tokens, vocab_size)
136277

278+
# Switch of Block Verify: when MTP>=3, using block verify for rejection sampler.
279+
using_block_verify = max_spec_len >= 3
280+
137281
# Create output buffer.
138282
output_token_ids = torch.empty(
139283
(batch_size, max_spec_len + 1),
@@ -149,25 +293,36 @@ def rejection_sample(
149293
if not sampling_metadata.all_random:
150294
# Rejection sampling for greedy sampling requests.
151295
target_argmax = target_probs.argmax(dim=-1)
152-
if min(num_draft_tokens) == 1 and max(
153-
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
154-
rejection_greedy_sample_spec_len_1_pytorch(
155-
output_token_ids,
156-
draft_token_ids,
157-
target_argmax,
158-
bonus_token_ids,
159-
)
160-
else:
161-
rejection_greedy_sample_pytorch(
296+
if TRITON_ASCEND_AVAILABLE:
297+
rejection_greedy_sample_kernel[(batch_size, )](
162298
output_token_ids,
163299
cu_num_draft_tokens,
164300
draft_token_ids,
165301
target_argmax,
166302
bonus_token_ids,
167-
num_draft_tokens,
168-
max_spec_len,
169303
is_greedy,
304+
max_spec_len,
170305
)
306+
else:
307+
if min(num_draft_tokens) == 1 and max(
308+
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
309+
rejection_greedy_sample_spec_len_1_pytorch(
310+
output_token_ids,
311+
draft_token_ids,
312+
target_argmax,
313+
bonus_token_ids,
314+
)
315+
else:
316+
rejection_greedy_sample_pytorch(
317+
output_token_ids,
318+
cu_num_draft_tokens,
319+
draft_token_ids,
320+
target_argmax,
321+
bonus_token_ids,
322+
num_draft_tokens,
323+
max_spec_len,
324+
is_greedy,
325+
)
171326
if sampling_metadata.all_greedy:
172327
return output_token_ids
173328

@@ -178,37 +333,68 @@ def rejection_sample(
178333
num_draft_tokens,
179334
sampling_metadata.generators,
180335
device,
181-
)
182-
183-
# Sample recovered tokens for each position.
184-
# [num_tokens]
185-
recovered_token_ids = sample_recovered_tokens(
186-
max_spec_len,
187-
num_draft_tokens,
188-
cu_num_draft_tokens,
189-
draft_token_ids,
190-
draft_probs,
191-
target_probs,
192-
sampling_metadata,
193-
device,
194-
)
336+
).to(torch.float32)
337+
338+
if not using_block_verify:
339+
# Sample recovered tokens for each position.
340+
# [num_tokens]
341+
recovered_token_ids = sample_recovered_tokens(
342+
max_spec_len,
343+
num_draft_tokens,
344+
cu_num_draft_tokens,
345+
draft_token_ids,
346+
draft_probs,
347+
target_probs,
348+
sampling_metadata,
349+
device,
350+
)
195351

196-
# Rejection sampling for random sampling requests.
197-
rejection_random_sample_pytorch(
198-
output_token_ids,
199-
cu_num_draft_tokens,
200-
draft_token_ids,
201-
draft_probs,
202-
target_probs,
203-
bonus_token_ids,
204-
recovered_token_ids,
205-
uniform_probs,
206-
is_greedy,
207-
max_spec_len,
208-
vocab_size,
209-
IS_NGRAM=draft_probs is None,
210-
# num_warps=1,
211-
)
352+
# Rejection sampling for random sampling requests.
353+
rejection_random_sample_pytorch(
354+
output_token_ids,
355+
cu_num_draft_tokens,
356+
draft_token_ids,
357+
draft_probs,
358+
target_probs,
359+
bonus_token_ids,
360+
recovered_token_ids,
361+
uniform_probs,
362+
is_greedy,
363+
max_spec_len,
364+
vocab_size,
365+
IS_NGRAM=draft_probs is None,
366+
# num_warps=1,
367+
)
368+
else:
369+
# MagicMTP: Improving acceptance rate with Block Verify.
370+
if TRITON_ASCEND_AVAILABLE:
371+
rejection_random_sample_block_verify_kernel[(batch_size, )](
372+
output_token_ids,
373+
cu_num_draft_tokens,
374+
draft_token_ids,
375+
draft_probs,
376+
target_probs,
377+
bonus_token_ids,
378+
uniform_probs,
379+
is_greedy,
380+
max_spec_len,
381+
vocab_size,
382+
NO_DRAFT_PROBS=draft_probs is None,
383+
multibuffer=True,
384+
)
385+
else:
386+
rejection_random_sample_block_verify_pytorch(output_token_ids,
387+
cu_num_draft_tokens,
388+
draft_token_ids,
389+
draft_probs,
390+
target_probs,
391+
bonus_token_ids,
392+
uniform_probs,
393+
is_greedy,
394+
max_spec_len,
395+
vocab_size,
396+
IS_NGRAM=draft_probs
397+
is None)
212398
return output_token_ids
213399

214400

@@ -504,4 +690,69 @@ def sample_recovered_tokens_pytorch(
504690
target_probs[token_idx, draft_token_id] = orig_prob
505691

506692

693+
def rejection_random_sample_block_verify_pytorch(
    output_token_ids,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens,  # [batch_size]
    draft_token_ids,  # [num_tokens]
    draft_probs,  # [num_tokens, vocab_size] or None
    target_probs,  # [num_tokens, vocab_size]
    bonus_token_ids,  # [batch_size]
    uniform_probs,  # [num_tokens]
    is_greedy,  # [batch_size]
    max_spec_len,
    vocab_size,  # kept for signature parity with the Triton kernel; unused
    IS_NGRAM=False,
):
    """PyTorch fallback of the block-verify rejection sampler.

    For each non-greedy request, compare the running (capped) product of
    acceptance ratios ``pi`` against the running product of uniform draws;
    a later draft token can therefore still be accepted as a block even if
    an earlier position failed.  On rejection, the position right after the
    accepted prefix is recovered with the target distribution's argmax;
    otherwise the bonus token is appended after all draft tokens.

    Writes results into ``output_token_ids`` in place; returns None.
    ``IS_NGRAM=True`` means there are no draft probs (treated as 1.0).
    """
    batch_size = output_token_ids.shape[0]

    for req_idx in range(batch_size):
        if is_greedy[req_idx]:
            # Greedy requests are handled by the greedy sampler.
            continue

        # This request's slice [start_idx, end_idx) of the flattened arrays.
        if req_idx == 0:
            start_idx = 0
        else:
            start_idx = cu_num_draft_tokens[req_idx - 1].item()
        end_idx = cu_num_draft_tokens[req_idx].item()
        num_draft_tokens = end_idx - start_idx

        rejected = False
        pi = 1.0
        uniform_prob = 1.0
        last_accepted_token_pos = -1
        for pos in range(num_draft_tokens):
            draft_token_id = draft_token_ids[start_idx + pos].item()

            target_prob = target_probs[start_idx + pos, draft_token_id].item()
            uniform_prob = uniform_prob * uniform_probs[start_idx + pos].item()

            if IS_NGRAM:
                draft_prob = 1.0
            else:
                draft_prob = draft_probs[start_idx + pos,
                                         draft_token_id].item()

            # Guard the division: the Triton kernel divides unconditionally
            # and relies on float division yielding inf, but in Python a zero
            # draft_prob would raise ZeroDivisionError.  A zero draft prob can
            # never be accepted, so skip the update and reject.
            if draft_prob > 0:
                pi = min(pi * target_prob / draft_prob, 1.0)
                accepted = pi >= uniform_prob
            else:
                accepted = False

            if accepted:
                last_accepted_token_pos = pos
                rejected = False
            else:
                rejected = True

        if last_accepted_token_pos > -1:
            # Emit every draft token in the accepted prefix.
            for pos in range(last_accepted_token_pos + 1):
                draft_token_id = draft_token_ids[start_idx + pos].item()
                output_token_ids[req_idx, pos] = draft_token_id

        if rejected:
            # Recover the slot after the accepted prefix with the target argmax.
            recovered_token_id = torch.argmax(
                target_probs[start_idx + last_accepted_token_pos + 1]).item()
            output_token_ids[req_idx,
                             last_accepted_token_pos + 1] = recovered_token_id
        else:
            # All draft tokens accepted: append the bonus token.
            bonus_token_id = bonus_token_ids[req_idx].item()
            output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
757+
507758
rs.expand_batch_to_tokens = expand_batch_to_tokens

0 commit comments

Comments
 (0)