1010 generate_uniform_probs )
1111from vllm .v1 .spec_decode .metadata import SpecDecodeMetadata
1212
13+ try :
14+ import triton
15+ import triton .language as tl
16+
    @triton.jit(do_not_specialize=["max_spec_len"])
    def rejection_greedy_sample_kernel(
        output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
        cu_num_draft_tokens_ptr,  # [batch_size]
        draft_token_ids_ptr,  # [num_tokens]
        target_argmax_ptr,  # [num_tokens]
        bonus_token_ids_ptr,  # [batch_size]
        is_greedy_ptr,  # [batch_size] or None
        max_spec_len,
    ):
        # Greedy rejection sampling; one Triton program per request.
        # A draft token is accepted while it equals the target model's argmax
        # at the same position; at the first mismatch the target argmax token
        # is stored instead and all later positions are left untouched. If
        # every draft token is accepted, the bonus token is appended.
        req_idx = tl.program_id(0)
        # FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
        # re-compilation may happen during runtime when is_greedy_ptr is None.
        is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr +
                                                               req_idx)
        if not is_greedy:
            # Early exit for non-greedy sampling requests.
            return

        # cu_num_draft_tokens_ptr holds a cumulative sum, so [start_idx,
        # end_idx) is this request's slice of the flattened token arrays.
        start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
                                                   req_idx - 1)
        end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
        num_draft_tokens = end_idx - start_idx

        rejected = False
        for pos in range(num_draft_tokens):
            if not rejected:
                draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
                target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
                # The emitted token is always the target argmax: it equals the
                # draft token when accepted and is the correction otherwise.
                tl.store(
                    output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                    target_argmax_id,
                )
                if draft_token_id != target_argmax_id:
                    # Reject.
                    rejected = True

        if not rejected:
            # If all tokens are accepted, append the bonus token.
            bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
            tl.store(
                output_token_ids_ptr + req_idx * (max_spec_len + 1) +
                num_draft_tokens,
                bonus_token_id,
            )
62+
    @triton.jit(do_not_specialize=["max_spec_len"])
    def rejection_random_sample_block_verify_kernel(
        output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
        cu_num_draft_tokens_ptr,  # [batch_size]
        draft_token_ids_ptr,  # [num_tokens]
        draft_probs_ptr,  # [num_tokens, vocab_size] or None
        target_probs_ptr,  # [num_tokens, vocab_size]
        bonus_token_ids_ptr,  # [batch_size]
        uniform_probs_ptr,  # [num_tokens]
        is_greedy_ptr,  # [batch_size]
        max_spec_len,
        vocab_size,
        NO_DRAFT_PROBS: tl.constexpr,
        SUB_BLOCK: tl.constexpr = 1500,
    ):
        # Block-verify random rejection sampling; one program per request.
        # Unlike token-by-token rejection, the whole draft block is scanned
        # while maintaining a running acceptance ratio
        #   pi = min(prod(target_prob / draft_prob), 1)
        # and a running product of the per-position uniform samples. A
        # position is accepted whenever pi >= the uniform product, so a later
        # position can rescue earlier mismatches.
        req_idx = tl.program_id(0)
        is_greedy = tl.load(is_greedy_ptr + req_idx)
        if is_greedy:
            # Early exit for greedy sampling requests.
            return

        # [start_idx, end_idx) is this request's slice of the flattened
        # token arrays (cu_num_draft_tokens_ptr is a cumulative sum).
        start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
                                                   req_idx - 1)
        end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
        num_draft_tokens = end_idx - start_idx

        rejected = False
        pi = 1.0
        uniform_prob = 1.0
        last_accepted_token_pos = -1

        for pos in range(num_draft_tokens):
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            target_prob = tl.load(target_probs_ptr +
                                  (start_idx + pos) * vocab_size +
                                  draft_token_id)
            tmp_uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
            uniform_prob = uniform_prob * tmp_uniform_prob

            if NO_DRAFT_PROBS:
                # N-gram drafting has no draft distribution; treat it as 1.
                draft_prob = 1
            else:
                draft_prob = tl.load(draft_probs_ptr +
                                     (start_idx + pos) * vocab_size +
                                     draft_token_id)

            pi = min(pi * target_prob / draft_prob, 1.0)
            if draft_prob > 0 and pi >= uniform_prob:
                last_accepted_token_pos = pos
                rejected = False
            else:
                # Only the outcome at the final position matters for whether
                # a recovered token (vs. the bonus token) is emitted.
                rejected = True

        if last_accepted_token_pos > -1:
            # Copy the accepted prefix of draft tokens into the output row.
            for pos in range(last_accepted_token_pos + 1):
                token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
                tl.store(
                    output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                    token_id)

        if rejected:
            # The last position was rejected: emit a recovered token at the
            # first unaccepted slot, chosen as the argmax of the target
            # distribution. The vocab is scanned in SUB_BLOCK-sized chunks to
            # bound on-chip buffer size.
            loop = (vocab_size + SUB_BLOCK - 1) // SUB_BLOCK
            global_recovered_id = -1
            global_max_p = -1.0
            for loop_i in range(loop):
                vocab_start = loop_i * SUB_BLOCK
                vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
                tmp_target_prob = tl.load(
                    target_probs_ptr +
                    (start_idx + last_accepted_token_pos + 1) * vocab_size +
                    vocab_offset,
                    mask=vocab_offset < vocab_size,
                    other=0)
                recovered_id = tl.argmax(tmp_target_prob, axis=-1)
                # NOTE(review): tl.get_element appears to be a triton-ascend
                # extension — confirm it exists on the target Triton build.
                max_p = tl.get_element(tmp_target_prob, (recovered_id, ))
                if max_p > global_max_p:
                    global_max_p = max_p
                    global_recovered_id = vocab_start + recovered_id
            tl.store(
                output_token_ids_ptr + req_idx * (max_spec_len + 1) +
                last_accepted_token_pos + 1, global_recovered_id)
        else:
            # All positions accepted: append the bonus token.
            bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
            tl.store(
                output_token_ids_ptr + req_idx * (max_spec_len + 1) +
                num_draft_tokens, bonus_token_id)
149+
150+ TRITON_ASCEND_AVAILABLE = True
151+ except ImportError :
152+ TRITON_ASCEND_AVAILABLE = False
153+
13154PLACEHOLDER_TOKEN_ID = - 1
14155GREEDY_TEMPERATURE = - 1
15156# Maximum number of speculative draft tokens allowed per request in a single
@@ -134,6 +275,9 @@ def rejection_sample(
134275 assert bonus_token_ids .is_contiguous ()
135276 assert target_probs .shape == (num_tokens , vocab_size )
136277
278+    # Block-verify switch: when MTP >= 3, use block verify in the
279+    # rejection sampler.
280+
137281 # Create output buffer.
138282 output_token_ids = torch .empty (
139283 (batch_size , max_spec_len + 1 ),
@@ -149,25 +293,36 @@ def rejection_sample(
149293 if not sampling_metadata .all_random :
150294 # Rejection sampling for greedy sampling requests.
151295 target_argmax = target_probs .argmax (dim = - 1 )
152- if min (num_draft_tokens ) == 1 and max (
153- num_draft_tokens ) == 1 and sampling_metadata .all_greedy :
154- rejection_greedy_sample_spec_len_1_pytorch (
155- output_token_ids ,
156- draft_token_ids ,
157- target_argmax ,
158- bonus_token_ids ,
159- )
160- else :
161- rejection_greedy_sample_pytorch (
296+ if TRITON_ASCEND_AVAILABLE :
297+ rejection_greedy_sample_kernel [(batch_size , )](
162298 output_token_ids ,
163299 cu_num_draft_tokens ,
164300 draft_token_ids ,
165301 target_argmax ,
166302 bonus_token_ids ,
167- num_draft_tokens ,
168- max_spec_len ,
169303 is_greedy ,
304+ max_spec_len ,
170305 )
306+ else :
307+ if min (num_draft_tokens ) == 1 and max (
308+ num_draft_tokens ) == 1 and sampling_metadata .all_greedy :
309+ rejection_greedy_sample_spec_len_1_pytorch (
310+ output_token_ids ,
311+ draft_token_ids ,
312+ target_argmax ,
313+ bonus_token_ids ,
314+ )
315+ else :
316+ rejection_greedy_sample_pytorch (
317+ output_token_ids ,
318+ cu_num_draft_tokens ,
319+ draft_token_ids ,
320+ target_argmax ,
321+ bonus_token_ids ,
322+ num_draft_tokens ,
323+ max_spec_len ,
324+ is_greedy ,
325+ )
171326 if sampling_metadata .all_greedy :
172327 return output_token_ids
173328
@@ -178,37 +333,68 @@ def rejection_sample(
178333 num_draft_tokens ,
179334 sampling_metadata .generators ,
180335 device ,
181- )
182-
183- # Sample recovered tokens for each position.
184- # [num_tokens]
185- recovered_token_ids = sample_recovered_tokens (
186- max_spec_len ,
187- num_draft_tokens ,
188- cu_num_draft_tokens ,
189- draft_token_ids ,
190- draft_probs ,
191- target_probs ,
192- sampling_metadata ,
193- device ,
194- )
336+ ).to (torch .float32 )
337+
338+ if not using_block_verify :
339+ # Sample recovered tokens for each position.
340+ # [num_tokens]
341+ recovered_token_ids = sample_recovered_tokens (
342+ max_spec_len ,
343+ num_draft_tokens ,
344+ cu_num_draft_tokens ,
345+ draft_token_ids ,
346+ draft_probs ,
347+ target_probs ,
348+ sampling_metadata ,
349+ device ,
350+ )
195351
196- # Rejection sampling for random sampling requests.
197- rejection_random_sample_pytorch (
198- output_token_ids ,
199- cu_num_draft_tokens ,
200- draft_token_ids ,
201- draft_probs ,
202- target_probs ,
203- bonus_token_ids ,
204- recovered_token_ids ,
205- uniform_probs ,
206- is_greedy ,
207- max_spec_len ,
208- vocab_size ,
209- IS_NGRAM = draft_probs is None ,
210- # num_warps=1,
211- )
352+ # Rejection sampling for random sampling requests.
353+ rejection_random_sample_pytorch (
354+ output_token_ids ,
355+ cu_num_draft_tokens ,
356+ draft_token_ids ,
357+ draft_probs ,
358+ target_probs ,
359+ bonus_token_ids ,
360+ recovered_token_ids ,
361+ uniform_probs ,
362+ is_greedy ,
363+ max_spec_len ,
364+ vocab_size ,
365+ IS_NGRAM = draft_probs is None ,
366+ # num_warps=1,
367+ )
368+ else :
369+ # MagicMTP: Improving acceptance rate with Block Verify.
370+ if TRITON_ASCEND_AVAILABLE :
371+ rejection_random_sample_block_verify_kernel [(batch_size , )](
372+ output_token_ids ,
373+ cu_num_draft_tokens ,
374+ draft_token_ids ,
375+ draft_probs ,
376+ target_probs ,
377+ bonus_token_ids ,
378+ uniform_probs ,
379+ is_greedy ,
380+ max_spec_len ,
381+ vocab_size ,
382+ NO_DRAFT_PROBS = draft_probs is None ,
383+ multibuffer = True ,
384+ )
385+ else :
386+ rejection_random_sample_block_verify_pytorch (output_token_ids ,
387+ cu_num_draft_tokens ,
388+ draft_token_ids ,
389+ draft_probs ,
390+ target_probs ,
391+ bonus_token_ids ,
392+ uniform_probs ,
393+ is_greedy ,
394+ max_spec_len ,
395+ vocab_size ,
396+ IS_NGRAM = draft_probs
397+ is None )
212398 return output_token_ids
213399
214400
@@ -504,4 +690,69 @@ def sample_recovered_tokens_pytorch(
504690 target_probs [token_idx , draft_token_id ] = orig_prob
505691
506692
def rejection_random_sample_block_verify_pytorch(
    output_token_ids,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens,  # [batch_size]
    draft_token_ids,  # [num_tokens]
    draft_probs,  # [num_tokens, vocab_size] or None
    target_probs,  # [num_tokens, vocab_size]
    bonus_token_ids,  # [batch_size]
    uniform_probs,  # [num_tokens]
    is_greedy,  # [batch_size]
    max_spec_len,
    vocab_size,
    IS_NGRAM=False,
):
    """Pure-PyTorch reference for block-verify rejection sampling.

    For every non-greedy request, scan its whole draft block while keeping a
    running acceptance ratio (product of target/draft probabilities, clamped
    to 1) and a running product of the per-position uniform samples. A
    position is accepted whenever the ratio is at least the uniform product,
    so later positions can rescue earlier mismatches. Accepted draft tokens
    are written into ``output_token_ids``; if the final position was
    rejected, the target argmax at the first unaccepted slot is emitted as a
    recovered token, otherwise the bonus token is appended.
    """
    num_reqs = output_token_ids.shape[0]

    for req in range(num_reqs):
        if is_greedy[req]:
            # Greedy requests are handled by the greedy sampler.
            continue

        # [lo, hi) is this request's slice of the flattened token arrays.
        lo = 0 if req == 0 else cu_num_draft_tokens[req - 1].item()
        hi = cu_num_draft_tokens[req].item()
        n_draft = hi - lo

        ratio = 1.0  # running min(prod(target/draft), 1)
        u_prod = 1.0  # running product of uniform samples
        accepted_upto = -1  # last accepted draft position
        last_rejected = False  # outcome at the final position
        for k in range(n_draft):
            tok = draft_token_ids[lo + k].item()
            u_prod = u_prod * uniform_probs[lo + k].item()
            p_target = target_probs[lo + k, tok].item()
            # N-gram drafting has no draft distribution; treat it as 1.
            p_draft = 1.0 if IS_NGRAM else draft_probs[lo + k, tok].item()

            ratio = min(ratio * p_target / p_draft, 1.0)
            if p_draft > 0 and ratio >= u_prod:
                accepted_upto = k
                last_rejected = False
            else:
                last_rejected = True

        # Copy the accepted prefix (empty when accepted_upto == -1).
        for k in range(accepted_upto + 1):
            output_token_ids[req, k] = draft_token_ids[lo + k].item()

        if last_rejected:
            # Recovered token: target argmax at the first unaccepted slot.
            recovered = torch.argmax(
                target_probs[lo + accepted_upto + 1]).item()
            output_token_ids[req, accepted_upto + 1] = recovered
        else:
            output_token_ids[req, n_draft] = bonus_token_ids[req].item()
756+
757+
507758rs .expand_batch_to_tokens = expand_batch_to_tokens
0 commit comments