import torch
import torch.nn as nn
import vllm.v1.sample.rejection_sampler as rs
+ from vllm.triton_utils import HAS_TRITON, tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import (RejectionSampler,
                                              apply_sampling_constraints,
                                              generate_uniform_probs)
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

+ if HAS_TRITON:
+
+     @triton.jit(do_not_specialize=["max_spec_len"])
+     def rejection_greedy_sample_kernel(
+             output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
+             cu_num_draft_tokens_ptr,  # [batch_size]
+             draft_token_ids_ptr,  # [num_tokens]
+             target_argmax_ptr,  # [num_tokens]
+             bonus_token_ids_ptr,  # [batch_size]
+             is_greedy_ptr,  # [batch_size] or None
+             max_spec_len,
+     ):
+         req_idx = tl.program_id(0)
+         # FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
+         # re-compilation may happen during runtime when is_greedy_ptr is None.
+         is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr +
+                                                                req_idx)
+         if not is_greedy:
+             # Early exit for non-greedy sampling requests.
+             return
+
+         start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
+                                                    req_idx - 1)
+         end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+         num_draft_tokens = end_idx - start_idx
+
+         rejected = False
+         for pos in range(num_draft_tokens):
+             if not rejected:
+                 draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+                 target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
+                 tl.store(
+                     output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
+                     target_argmax_id,
+                 )
+                 if draft_token_id != target_argmax_id:
+                     # Reject.
+                     rejected = True
+
+         if not rejected:
+             # If all tokens are accepted, append the bonus token.
+             bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
+             tl.store(
+                 output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+                 num_draft_tokens,
+                 bonus_token_id,
+             )
+
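+     # Informal sketch of the block-verify rule the next kernel implements
+     # (names follow the kernel body): walking draft positions in order, it
+     # accumulates
+     #     pi           <- min(pi * target_prob / draft_prob, 1.0)
+     #     uniform_prob <- uniform_prob * u[pos]
+     # and marks the prefix up to pos as accepted whenever pi >= uniform_prob.
+     # Unlike token-by-token rejection, a later position can rescue an earlier
+     # failure. Made-up example with NO_DRAFT_PROBS, per-position ratios
+     # (0.5, 1.0) and uniforms (0.8, 0.5): token-wise sampling rejects at
+     # position 0 (0.5 < 0.8), but here position 1 gives pi = 0.5 >=
+     # uniform_prob = 0.4, so both draft tokens are accepted and the bonus
+     # token is appended.
+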
+     @triton.jit(do_not_specialize=["max_spec_len"])
+     def rejection_random_sample_block_verify_kernel(
+             output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
+             cu_num_draft_tokens_ptr,  # [batch_size]
+             draft_token_ids_ptr,  # [num_tokens]
+             draft_probs_ptr,  # [num_tokens, vocab_size] or None
+             target_probs_ptr,  # [num_tokens, vocab_size]
+             bonus_token_ids_ptr,  # [batch_size]
+             uniform_probs_ptr,  # [num_tokens]
+             is_greedy_ptr,  # [batch_size]
+             max_spec_len,
+             vocab_size,
+             NO_DRAFT_PROBS: tl.constexpr,
+             SUB_BLOCK: tl.constexpr = 1500,
+     ):
+         req_idx = tl.program_id(0)
+         is_greedy = tl.load(is_greedy_ptr + req_idx)
+         if is_greedy:
+             # Early exit for greedy sampling requests.
+             return
+
+         start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
+                                                    req_idx - 1)
+         end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+         num_draft_tokens = end_idx - start_idx
+
+         rejected = False
+         pi = 1.0
+         uniform_prob = 1.0
+         last_accepted_token_pos = -1
+
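+         # Scan every draft position, tracking the last position at which the
+         # cumulative ratio `pi` still dominates the cumulative uniform sample.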
+         for pos in range(num_draft_tokens):
+             draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+             target_prob = tl.load(target_probs_ptr +
+                                   (start_idx + pos) * vocab_size +
+                                   draft_token_id)
+             tmp_uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
+             uniform_prob = uniform_prob * tmp_uniform_prob
+
+             if NO_DRAFT_PROBS:
+                 draft_prob = 1
+             else:
+                 draft_prob = tl.load(draft_probs_ptr +
+                                      (start_idx + pos) * vocab_size +
+                                      draft_token_id)
+
+             pi = min(pi * target_prob / draft_prob, 1.0)
+             if draft_prob > 0 and pi >= uniform_prob:
+                 last_accepted_token_pos = pos
+                 rejected = False
+             else:
+                 rejected = True
+
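+         # Commit the accepted prefix, then either recover a token at the first
+         # unaccepted position (deterministic argmax of the target row) or, if
+         # every draft token was accepted, append the bonus token.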
+         if last_accepted_token_pos > -1:
+             for pos in range(last_accepted_token_pos + 1):
+                 token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+                 tl.store(
+                     output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
+                     token_id)
+
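+         # The argmax over the target row is computed in SUB_BLOCK-sized tiles
+         # so the kernel never materializes the full vocab row at once.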
+         if rejected:
+             loop = (vocab_size + SUB_BLOCK - 1) // SUB_BLOCK
+             global_recovered_id = -1
+             global_max_p = -1.0
+             for loop_i in range(loop):
+                 vocab_start = loop_i * SUB_BLOCK
+                 vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
+                 tmp_target_prob = tl.load(
+                     target_probs_ptr +
+                     (start_idx + last_accepted_token_pos + 1) * vocab_size +
+                     vocab_offset,
+                     mask=vocab_offset < vocab_size,
+                     other=0)
+                 recovered_id = tl.argmax(tmp_target_prob, axis=-1)
+                 max_p = tl.get_element(tmp_target_prob, (recovered_id, ))
+                 if max_p > global_max_p:
+                     global_max_p = max_p
+                     global_recovered_id = vocab_start + recovered_id
+             tl.store(
+                 output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+                 last_accepted_token_pos + 1, global_recovered_id)
+         else:
+             bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
+             tl.store(
+                 output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+                 num_draft_tokens, bonus_token_id)
+
+
PLACEHOLDER_TOKEN_ID = -1
GREEDY_TEMPERATURE = -1
# Maximum number of speculative draft tokens allowed per request in a single
@@ -134,6 +271,9 @@ def rejection_sample(
    assert bonus_token_ids.is_contiguous()
    assert target_probs.shape == (num_tokens, vocab_size)

+     # Use block verify when num_speculative_tokens >= 3.
+     using_block_verify = max_spec_len >= 3
+
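+     # Since max_spec_len is the largest draft length in the batch, a mixed
+     # batch takes the block-verify path whenever any request drafts three or
+     # more tokens.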
    # Create output buffer.
    output_token_ids = torch.empty(
        (batch_size, max_spec_len + 1),
@@ -149,25 +289,36 @@ def rejection_sample(
    if not sampling_metadata.all_random:
        # Rejection sampling for greedy sampling requests.
        target_argmax = target_probs.argmax(dim=-1)
-         if min(num_draft_tokens) == 1 and max(
-                 num_draft_tokens) == 1 and sampling_metadata.all_greedy:
-             rejection_greedy_sample_spec_len_1_pytorch(
-                 output_token_ids,
-                 draft_token_ids,
-                 target_argmax,
-                 bonus_token_ids,
-             )
-         else:
-             rejection_greedy_sample_pytorch(
+         if HAS_TRITON:
+             rejection_greedy_sample_kernel[(batch_size, )](
                output_token_ids,
                cu_num_draft_tokens,
                draft_token_ids,
                target_argmax,
                bonus_token_ids,
-                 num_draft_tokens,
-                 max_spec_len,
                is_greedy,
+                 max_spec_len,
            )
+         else:
+             if min(num_draft_tokens) == 1 and max(
+                     num_draft_tokens) == 1 and sampling_metadata.all_greedy:
+                 rejection_greedy_sample_spec_len_1_pytorch(
+                     output_token_ids,
+                     draft_token_ids,
+                     target_argmax,
+                     bonus_token_ids,
+                 )
+             else:
+                 rejection_greedy_sample_pytorch(
+                     output_token_ids,
+                     cu_num_draft_tokens,
+                     draft_token_ids,
+                     target_argmax,
+                     bonus_token_ids,
+                     num_draft_tokens,
+                     max_spec_len,
+                     is_greedy,
+                 )
    if sampling_metadata.all_greedy:
        return output_token_ids

@@ -178,37 +329,68 @@ def rejection_sample(
        num_draft_tokens,
        sampling_metadata.generators,
        device,
-     )
-
-     # Sample recovered tokens for each position.
-     # [num_tokens]
-     recovered_token_ids = sample_recovered_tokens(
-         max_spec_len,
-         num_draft_tokens,
-         cu_num_draft_tokens,
-         draft_token_ids,
-         draft_probs,
-         target_probs,
-         sampling_metadata,
-         device,
-     )
+     ).to(torch.float32)
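+     # Presumably cast to float32 so the cumulative product of uniforms in the
+     # block-verify paths below is evaluated at full precision.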
+
+     if not using_block_verify:
+         # Sample recovered tokens for each position.
+         # [num_tokens]
+         recovered_token_ids = sample_recovered_tokens(
+             max_spec_len,
+             num_draft_tokens,
+             cu_num_draft_tokens,
+             draft_token_ids,
+             draft_probs,
+             target_probs,
+             sampling_metadata,
+             device,
+         )

-     # Rejection sampling for random sampling requests.
-     rejection_random_sample_pytorch(
-         output_token_ids,
-         cu_num_draft_tokens,
-         draft_token_ids,
-         draft_probs,
-         target_probs,
-         bonus_token_ids,
-         recovered_token_ids,
-         uniform_probs,
-         is_greedy,
-         max_spec_len,
-         vocab_size,
-         IS_NGRAM=draft_probs is None,
-         # num_warps=1,
-     )
+         # Rejection sampling for random sampling requests.
+         rejection_random_sample_pytorch(
+             output_token_ids,
+             cu_num_draft_tokens,
+             draft_token_ids,
+             draft_probs,
+             target_probs,
+             bonus_token_ids,
+             recovered_token_ids,
+             uniform_probs,
+             is_greedy,
+             max_spec_len,
+             vocab_size,
+             IS_NGRAM=draft_probs is None,
+             # num_warps=1,
+         )
+     else:
+         # MagicMTP: improve the acceptance rate with block verify.
+         if HAS_TRITON:
+             rejection_random_sample_block_verify_kernel[(batch_size, )](
+                 output_token_ids,
+                 cu_num_draft_tokens,
+                 draft_token_ids,
+                 draft_probs,
+                 target_probs,
+                 bonus_token_ids,
+                 uniform_probs,
+                 is_greedy,
+                 max_spec_len,
+                 vocab_size,
+                 NO_DRAFT_PROBS=draft_probs is None,
+                 multibuffer=True,
+             )
+         else:
+             rejection_random_sample_block_verify_pytorch(
+                 output_token_ids,
+                 cu_num_draft_tokens,
+                 draft_token_ids,
+                 draft_probs,
+                 target_probs,
+                 bonus_token_ids,
+                 uniform_probs,
+                 is_greedy,
+                 max_spec_len,
+                 vocab_size,
+                 IS_NGRAM=draft_probs is None,
+             )
    return output_token_ids


@@ -504,4 +686,69 @@ def sample_recovered_tokens_pytorch(
            target_probs[token_idx, draft_token_id] = orig_prob


+ def rejection_random_sample_block_verify_pytorch(
+     output_token_ids,  # [batch_size, max_spec_len + 1]
+     cu_num_draft_tokens,  # [batch_size]
+     draft_token_ids,  # [num_tokens]
+     draft_probs,  # [num_tokens, vocab_size] or None
+     target_probs,  # [num_tokens, vocab_size]
+     bonus_token_ids,  # [batch_size]
+     uniform_probs,  # [num_tokens]
+     is_greedy,  # [batch_size]
+     max_spec_len,
+     vocab_size,
+     IS_NGRAM=False,
+ ):
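+     # Pure-PyTorch mirror of rejection_random_sample_block_verify_kernel,
+     # used as the fallback when Triton is unavailable.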
+     batch_size = output_token_ids.shape[0]
+
+     for req_idx in range(batch_size):
+         if is_greedy[req_idx]:
+             continue
+
+         if req_idx == 0:
+             start_idx = 0
+         else:
+             start_idx = cu_num_draft_tokens[req_idx - 1].item()
+         end_idx = cu_num_draft_tokens[req_idx].item()
+         num_draft_tokens = end_idx - start_idx
+
+         rejected = False
+         pi = 1.0
+         uniform_prob = 1.0
+         last_accepted_token_pos = -1
+         for pos in range(num_draft_tokens):
+             draft_token_id = draft_token_ids[start_idx + pos].item()
+
+             target_prob = target_probs[start_idx + pos, draft_token_id].item()
+             uniform_prob = uniform_prob * uniform_probs[start_idx + pos].item()
+
+             if IS_NGRAM:
+                 draft_prob = 1.0
+             else:
+                 draft_prob = draft_probs[start_idx + pos,
+                                          draft_token_id].item()
+
+             pi = min(pi * target_prob / draft_prob, 1.0)
+
+             if draft_prob > 0 and pi >= uniform_prob:
+                 last_accepted_token_pos = pos
+                 rejected = False
+             else:
+                 rejected = True
+
+         if last_accepted_token_pos > -1:
+             for pos in range(last_accepted_token_pos + 1):
+                 draft_token_id = draft_token_ids[start_idx + pos].item()
+                 output_token_ids[req_idx, pos] = draft_token_id
+
+         if rejected:
+             recovered_token_id = torch.argmax(
+                 target_probs[start_idx + last_accepted_token_pos + 1]).item()
+             output_token_ids[req_idx,
+                              last_accepted_token_pos + 1] = recovered_token_id
+         else:
+             bonus_token_id = bonus_token_ids[req_idx].item()
+             output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
+
+
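+ # Illustrative walk-through of the reference above (made-up numbers): one
+ # request, two NGRAM draft tokens [0, 0], vocab_size=2, target rows
+ # [[0.4, 0.6], [0.8, 0.2]], uniforms [0.7, 0.5]:
+ #   pos 0: pi = min(1.0 * 0.4, 1.0) = 0.4,  uniform_prob = 0.7   -> not accepted
+ #   pos 1: pi = min(0.4 * 0.8, 1.0) = 0.32, uniform_prob = 0.35  -> not accepted
+ # last_accepted_token_pos stays -1, so the recovered token
+ # argmax(target_probs[0]) = 1 is written at output position 0.
+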
rs.expand_batch_to_tokens = expand_batch_to_tokens