     generate_uniform_probs)
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
+try:
+    import triton
+    import triton.language as tl
+
+
+    @triton.jit(do_not_specialize=["max_spec_len"])
+    def rejection_greedy_sample_kernel(
+            output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
+            cu_num_draft_tokens_ptr,  # [batch_size]
+            draft_token_ids_ptr,  # [num_tokens]
+            target_argmax_ptr,  # [num_tokens]
+            bonus_token_ids_ptr,  # [batch_size]
+            is_greedy_ptr,  # [batch_size] or None
+            max_spec_len,
+    ):
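+        # One program per request: write the target argmax at each position
+        # up to and including the first draft/target mismatch, then stop; the
+        # bonus token is appended only if every draft token was accepted.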
+        req_idx = tl.program_id(0)
+        # FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
+        # re-compilation may happen during runtime when is_greedy_ptr is None.
+        is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr +
+                                                               req_idx)
+        if not is_greedy:
+            # Early exit for non-greedy sampling requests.
+            return
+
+        start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
+                                                   req_idx - 1)
+        end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+        num_draft_tokens = end_idx - start_idx
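+        # cu_num_draft_tokens_ptr is an inclusive prefix sum: request req_idx
+        # owns the flat token range [start_idx, end_idx).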
+
+        rejected = False
+        for pos in range(num_draft_tokens):
+            if not rejected:
+                draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+                target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
+                tl.store(
+                    output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
+                    target_argmax_id,
+                )
+                if draft_token_id != target_argmax_id:
+                    # Reject.
+                    rejected = True
+
+        if not rejected:
+            # If all tokens are accepted, append the bonus token.
+            bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
+            tl.store(
+                output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+                num_draft_tokens,
+                bonus_token_id,
+            )
+
+
+    @triton.jit(do_not_specialize=["max_spec_len"])
+    def rejection_random_sample_block_verify_kernel(
+            output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
+            cu_num_draft_tokens_ptr,  # [batch_size]
+            draft_token_ids_ptr,  # [num_tokens]
+            draft_probs_ptr,  # [num_tokens, vocab_size] or None
+            target_probs_ptr,  # [num_tokens, vocab_size]
+            bonus_token_ids_ptr,  # [batch_size]
+            uniform_probs_ptr,  # [num_tokens]
+            is_greedy_ptr,  # [batch_size]
+            max_spec_len,
+            vocab_size,
+            NO_DRAFT_PROBS: tl.constexpr,
+            SUB_BLOCK: tl.constexpr = 1500,
+    ):
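+        # Verifies the draft block as a whole (one program per request):
+        # find the longest accepted prefix, then either recover a token at
+        # the first rejected position or append the bonus token.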
+        req_idx = tl.program_id(0)
+        is_greedy = tl.load(is_greedy_ptr + req_idx)
+        if is_greedy:
+            # Early exit for greedy sampling requests.
+            return
+
+        start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
+                                                   req_idx - 1)
+        end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+        num_draft_tokens = end_idx - start_idx
+
+        rejected = False
+        pi = 1.0
+        uniform_prob = 1.0
+        last_accepted_token_pos = -1
+
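+        # Accept rule implemented below: maintain the running ratio
+        #   pi = min(prod_i target_prob_i / draft_prob_i, 1.0)
+        # and the running product of uniforms; position pos is accepted while
+        # pi >= uniform_prob, and the prefix up to the last accepted position
+        # is kept.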
+        for pos in range(num_draft_tokens):
+            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+            target_prob = tl.load(target_probs_ptr +
+                                  (start_idx + pos) * vocab_size +
+                                  draft_token_id)
+            tmp_uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
+            uniform_prob = uniform_prob * tmp_uniform_prob
+
+            if NO_DRAFT_PROBS:
+                draft_prob = 1
+            else:
+                draft_prob = tl.load(draft_probs_ptr +
+                                     (start_idx + pos) * vocab_size +
+                                     draft_token_id)
+
+            pi = min(pi * target_prob / draft_prob, 1.0)
+            if draft_prob > 0 and pi >= uniform_prob:
+                last_accepted_token_pos = pos
+                rejected = False
+            else:
+                rejected = True
+
+        if last_accepted_token_pos > -1:
+            for pos in range(last_accepted_token_pos + 1):
+                token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+                tl.store(
+                    output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
+                    token_id)
+
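+        # Greedy recovery: if the block was cut short, take the argmax of the
+        # target distribution at the first rejected position, scanning the
+        # vocabulary in SUB_BLOCK-sized tiles.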
+        if rejected:
+            loop = (vocab_size + SUB_BLOCK - 1) // SUB_BLOCK
+            global_recovered_id = -1
+            global_max_p = -1.0
+            for loop_i in range(loop):
+                vocab_start = loop_i * SUB_BLOCK
+                vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
+                tmp_target_prob = tl.load(
+                    target_probs_ptr +
+                    (start_idx + last_accepted_token_pos + 1) * vocab_size +
+                    vocab_offset,
+                    mask=vocab_offset < vocab_size,
+                    other=0)
+                recovered_id = tl.argmax(tmp_target_prob, axis=-1)
+                max_p = tl.get_element(tmp_target_prob, (recovered_id,))
+                if max_p > global_max_p:
+                    global_max_p = max_p
+                    global_recovered_id = vocab_start + recovered_id
+            tl.store(
+                output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+                last_accepted_token_pos + 1, global_recovered_id)
+        else:
+            bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
+            tl.store(
+                output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+                num_draft_tokens, bonus_token_id)
+
+
+    TRITON_ASCEND_AVAILABLE = True
+except ImportError:
+    TRITON_ASCEND_AVAILABLE = False
+
 PLACEHOLDER_TOKEN_ID = -1
 GREEDY_TEMPERATURE = -1
 # Maximum number of speculative draft tokens allowed per request in a single
@@ -134,6 +278,9 @@ def rejection_sample(
     assert bonus_token_ids.is_contiguous()
     assert target_probs.shape == (num_tokens, vocab_size)
 
+    # Block Verify switch: when MTP >= 3, use block verify in the rejection
+    # sampler.
+    using_block_verify = max_spec_len >= 3
+
     # Create output buffer.
     output_token_ids = torch.empty(
         (batch_size, max_spec_len + 1),
@@ -149,25 +296,36 @@ def rejection_sample(
     if not sampling_metadata.all_random:
         # Rejection sampling for greedy sampling requests.
         target_argmax = target_probs.argmax(dim=-1)
-        if min(num_draft_tokens) == 1 and max(
-                num_draft_tokens) == 1 and sampling_metadata.all_greedy:
-            rejection_greedy_sample_spec_len_1_pytorch(
-                output_token_ids,
-                draft_token_ids,
-                target_argmax,
-                bonus_token_ids,
-            )
-        else:
-            rejection_greedy_sample_pytorch(
+        if TRITON_ASCEND_AVAILABLE:
+            rejection_greedy_sample_kernel[(batch_size,)](
                 output_token_ids,
                 cu_num_draft_tokens,
                 draft_token_ids,
                 target_argmax,
                 bonus_token_ids,
-                num_draft_tokens,
-                max_spec_len,
                 is_greedy,
+                max_spec_len,
             )
+        else:
+            if min(num_draft_tokens) == 1 and max(
+                    num_draft_tokens) == 1 and sampling_metadata.all_greedy:
+                rejection_greedy_sample_spec_len_1_pytorch(
+                    output_token_ids,
+                    draft_token_ids,
+                    target_argmax,
+                    bonus_token_ids,
+                )
+            else:
+                rejection_greedy_sample_pytorch(
+                    output_token_ids,
+                    cu_num_draft_tokens,
+                    draft_token_ids,
+                    target_argmax,
+                    bonus_token_ids,
+                    num_draft_tokens,
+                    max_spec_len,
+                    is_greedy,
+                )
     if sampling_metadata.all_greedy:
         return output_token_ids
 
@@ -178,37 +336,68 @@ def rejection_sample(
         num_draft_tokens,
         sampling_metadata.generators,
         device,
-    )
-
-    # Sample recovered tokens for each position.
-    # [num_tokens]
-    recovered_token_ids = sample_recovered_tokens(
-        max_spec_len,
-        num_draft_tokens,
-        cu_num_draft_tokens,
-        draft_token_ids,
-        draft_probs,
-        target_probs,
-        sampling_metadata,
-        device,
-    )
+    ).to(torch.float32)
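+    # Cast to float32 so the running products of uniforms computed in the
+    # block-verify paths keep enough precision (rationale assumed; only the
+    # cast itself is part of this change).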
+
+    if not using_block_verify:
+        # Sample recovered tokens for each position.
+        # [num_tokens]
+        recovered_token_ids = sample_recovered_tokens(
+            max_spec_len,
+            num_draft_tokens,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            draft_probs,
+            target_probs,
+            sampling_metadata,
+            device,
+        )
 
-    # Rejection sampling for random sampling requests.
-    rejection_random_sample_pytorch(
-        output_token_ids,
-        cu_num_draft_tokens,
-        draft_token_ids,
-        draft_probs,
-        target_probs,
-        bonus_token_ids,
-        recovered_token_ids,
-        uniform_probs,
-        is_greedy,
-        max_spec_len,
-        vocab_size,
-        IS_NGRAM=draft_probs is None,
-        # num_warps=1,
-    )
+        # Rejection sampling for random sampling requests.
+        rejection_random_sample_pytorch(
+            output_token_ids,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            draft_probs,
+            target_probs,
+            bonus_token_ids,
+            recovered_token_ids,
+            uniform_probs,
+            is_greedy,
+            max_spec_len,
+            vocab_size,
+            IS_NGRAM=draft_probs is None,
+            # num_warps=1,
+        )
+    else:
+        # MagicMTP: improve the acceptance rate with Block Verify.
+        if TRITON_ASCEND_AVAILABLE:
+            rejection_random_sample_block_verify_kernel[(batch_size,)](
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                draft_probs,
+                target_probs,
+                bonus_token_ids,
+                uniform_probs,
+                is_greedy,
+                max_spec_len,
+                vocab_size,
+                NO_DRAFT_PROBS=draft_probs is None,
+                multibuffer=True,
+            )
+        else:
+            rejection_random_sample_block_verify_pytorch(
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                draft_probs,
+                target_probs,
+                bonus_token_ids,
+                uniform_probs,
+                is_greedy,
+                max_spec_len,
+                vocab_size,
+                IS_NGRAM=draft_probs is None,
+            )
     return output_token_ids
 
 
@@ -504,4 +693,69 @@ def sample_recovered_tokens_pytorch(
             target_probs[token_idx, draft_token_id] = orig_prob
 
 
+def rejection_random_sample_block_verify_pytorch(
+        output_token_ids,  # [batch_size, max_spec_len + 1]
+        cu_num_draft_tokens,  # [batch_size]
+        draft_token_ids,  # [num_tokens]
+        draft_probs,  # [num_tokens, vocab_size] or None
+        target_probs,  # [num_tokens, vocab_size]
+        bonus_token_ids,  # [batch_size]
+        uniform_probs,  # [num_tokens]
+        is_greedy,  # [batch_size]
+        max_spec_len,
+        vocab_size,
+        IS_NGRAM=False,
+):
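+    # Pure-PyTorch fallback mirroring
+    # rejection_random_sample_block_verify_kernel for environments without
+    # triton-ascend; the accept rule and recovery logic are identical.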
+    batch_size = output_token_ids.shape[0]
+
+    for req_idx in range(batch_size):
+        if is_greedy[req_idx]:
+            continue
+
+        if req_idx == 0:
+            start_idx = 0
+        else:
+            start_idx = cu_num_draft_tokens[req_idx - 1].item()
+        end_idx = cu_num_draft_tokens[req_idx].item()
+        num_draft_tokens = end_idx - start_idx
+
+        rejected = False
+        pi = 1.0
+        uniform_prob = 1.0
+        last_accepted_token_pos = -1
+        for pos in range(num_draft_tokens):
+            draft_token_id = draft_token_ids[start_idx + pos].item()
+
+            target_prob = target_probs[start_idx + pos, draft_token_id].item()
+            uniform_prob = uniform_prob * uniform_probs[start_idx + pos].item()
+
+            if IS_NGRAM:
+                draft_prob = 1.0
+            else:
+                draft_prob = draft_probs[start_idx + pos,
+                                         draft_token_id].item()
+
+            pi = min(pi * target_prob / draft_prob, 1.0)
+
+            if draft_prob > 0 and pi >= uniform_prob:
+                last_accepted_token_pos = pos
+                rejected = False
+            else:
+                rejected = True
+
+        if last_accepted_token_pos > -1:
+            for pos in range(last_accepted_token_pos + 1):
+                draft_token_id = draft_token_ids[start_idx + pos].item()
+                output_token_ids[req_idx, pos] = draft_token_id
+
+        if rejected:
+            recovered_token_id = torch.argmax(
+                target_probs[start_idx + last_accepted_token_pos + 1]).item()
+            output_token_ids[req_idx,
+                             last_accepted_token_pos + 1] = recovered_token_id
+        else:
+            bonus_token_id = bonus_token_ids[req_idx].item()
+            output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
+
+
 rs.expand_batch_to_tokens = expand_batch_to_tokens