@@ -53,6 +53,9 @@ class SamplingTensors:
     mask_top_logprob: torch.Tensor
         Mask for requests with top_logprob.
         shape: (LOGPROB_TOP_K_MAX + 1, batch_size, )
+    mask_prompt: torch.Tensor
+        Mask for requests with repetition penalty (prompt part)
+        shape: (batch_size, vocab_size)
     temperatures: torch.Tensor
         Tensor for temperature values
         shape: (batch_size, )
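
Note: `mask_prompt` is a per-request boolean mask over the vocabulary marking which token ids occur in the prompt, so prompt tokens can also be subject to the repetition penalty. A minimal sketch of how such a mask could be built (the `make_prompt_mask` helper and `prompt_token_ids` name below are hypothetical, not part of this diff):

import torch

def make_prompt_mask(prompt_token_ids, vocab_size: int) -> torch.Tensor:
    # Boolean mask of shape (vocab_size,): True for every token id seen in the prompt.
    mask = torch.zeros(vocab_size, dtype=torch.bool)
    mask[torch.tensor(prompt_token_ids, dtype=torch.long)] = True
    return mask

# One such mask per request; stacking them (as `torch.stack(list_mask_prompt)` does
# later in this diff) yields the (batch_size, vocab_size) tensor documented above.
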
@@ -85,6 +88,7 @@ class SamplingTensors:
     mask_random: torch.Tensor
     mask_greedy: torch.Tensor
     mask_top_logprob: torch.Tensor
+    mask_prompt: torch.Tensor
     temperatures: torch.Tensor
     top_ps: torch.Tensor
     top_ks: torch.Tensor
@@ -102,6 +106,7 @@ def from_lists(
         dev,
         list_mask_random: List[bool],
         list_mask_top_logprob: List[List[bool]],
+        list_mask_prompt: List[torch.Tensor],
         list_temperatures: List[float],
         list_top_ps: List[float],
         list_top_ks: List[int],
@@ -124,6 +129,7 @@ def from_lists(
         )
         # `mask_top_logprob` will be on cpu
         mask_top_logprob = torch.from_numpy(list_mask_top_logprob)
+        mask_prompt = torch.stack(list_mask_prompt)
         temp = torch.tensor(
             list_temperatures,
             dtype=dtype,
@@ -185,6 +191,7 @@ def from_lists(
             mask_random,
             mask_greedy,
             mask_top_logprob,
+            mask_prompt,
             temp.to(device=dev, non_blocking=True),
             top_ps.to(device=dev, non_blocking=True),
             top_ks.to(device=dev, non_blocking=True),
@@ -250,6 +257,7 @@ def from_sampling_params(
         vocab_size: int,
     ):
         list_mask_random = []
+        list_mask_prompt = []
         list_temperatures = []
         list_top_ps = []
         list_top_ks = []
@@ -307,6 +315,7 @@ def from_sampling_params(
             list_frequency_penalties.append(param.frequency_penalty)
             list_presence_penalties.append(param.presence_penalty)
             list_repetition_penalties.append(param.repetition_penalty)
+            list_mask_prompt.append(param.mask_prompt)
 
             if param.logit_bias_index:
                 assert param.logit_bias_value
@@ -348,6 +357,7 @@ def from_sampling_params(
             dev,
             list_mask_random,
             list_mask_top_logprob,
+            list_mask_prompt,
             list_temperatures,
             list_top_ps,
             list_top_ks,
@@ -372,20 +382,39 @@ def from_sampling_params(
         )
 
 
-def adjust_logits(logits, sampling_metadata, vocab_size):
+def get_bin_counts_and_mask(
+    tokens: torch.Tensor,
+    vocab_size: int,
+    num_seqs: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
+                             dtype=torch.long,
+                             device=tokens.device)
+    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
+    bin_counts = bin_counts[:, :vocab_size]
+    mask = bin_counts > 0
+
+    return bin_counts, mask
+
+
+def adjust_logits(
+    logits: torch.Tensor,
+    sampling_state: SamplingState,
+    vocab_size: int):
     batch_size = logits.shape[0]
     (
         apply_top_p_top_k,
         apply_penalty,
         apply_bias,
         sampling_tensors,
     ) = (
-        sampling_metadata.apply_top_p_top_k,
-        sampling_metadata.apply_penalty,
-        sampling_metadata.apply_bias,
-        sampling_metadata.sampling_tensors,
+        sampling_state.apply_top_p_top_k,
+        sampling_state.apply_penalty,
+        sampling_state.apply_bias,
+        sampling_state.sampling_tensors,
     )
     (
+        prompt_mask,
         temp_t,
         top_ps_t,
         top_ks_t,
@@ -396,6 +425,7 @@ def adjust_logits(logits, sampling_metadata, vocab_size):
         logit_bias_indices_t,
         logit_bias_values_t,
     ) = (
+        sampling_tensors.mask_prompt,
         sampling_tensors.temperatures,
         sampling_tensors.top_ps,
         sampling_tensors.top_ks,
@@ -411,20 +441,30 @@ def adjust_logits(logits, sampling_metadata, vocab_size):
     # (e.g., repetition penalty, frequency/presence penalty, logit bias, temperature...)
     # in the right order.
     if apply_penalty:
+        bin_counts, output_mask = get_bin_counts_and_mask(
+            past_output_tokens_t,
+            vocab_size,
+            batch_size,
+        )
+
+        # The vLLM and HF implementations of repetition penalty were verified to be equivalent;
+        # a combination of the two is used here (see the references below).
+        # The repetition penalty itself follows the vLLM approach
+        # https://github.com/vllm-project/vllm/blob/0580aab02ffe60fee50bddc80b787828eb233c44/vllm/model_executor/layers/sampler.py#L177
+        # combined with the RepetitionPenaltyLogitsProcessor approach from HF transformers
+        # https://github.com/huggingface/transformers/blob/de11e654c962d5b23eb53a4387cd637b01987491/src/transformers/generation/logits_process.py#L332C1-L339C22
+        # where `score` corresponds to the logits
+        # https://github.com/huggingface/transformers/blob/de11e654c962d5b23eb53a4387cd637b01987491/src/transformers/generation/logits_process.py#L76C1-L78C92
         repetition_penalties_t = repetition_penalties_t[:, None].repeat(1, vocab_size)
+        prompt_mask = prompt_mask.to(repetition_penalties_t.device)
+        repetition_penalties_t[~(prompt_mask | output_mask)] = 1.0
         logits = torch.where(
             logits > 0, logits / repetition_penalties_t, logits * repetition_penalties_t
         )
-        bin_counts = torch.zeros(
-            (batch_size, vocab_size + 1), dtype=torch.long, device=logits.device
-        )
-        bin_counts.scatter_add_(
-            1, past_output_tokens_t, torch.ones_like(past_output_tokens_t)
-        )
-        bin_counts = bin_counts[:, :vocab_size]
-        mask = bin_counts > 0
+
+        # Calculate frequency and presence penalties
         logits -= frequency_penalties_t.unsqueeze_(dim=1) * bin_counts
-        logits -= presence_penalties_t.unsqueeze_(dim=1) * mask
+        logits -= presence_penalties_t.unsqueeze_(dim=1) * output_mask
 
     # Adjust temperature
     logits.div_(temp_t.unsqueeze(dim=1))
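
To make the penalty order above concrete, here is a small self-contained sketch with toy numbers; the helper mirrors `get_bin_counts_and_mask` from this diff, while the penalty values (1.2, 0.1, 0.2) and tensor shapes are made up for illustration:

import torch

def get_bin_counts_and_mask(tokens, vocab_size, num_seqs):
    # Per-sequence counts of generated tokens, plus a boolean presence mask.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1), dtype=torch.long, device=tokens.device)
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    bin_counts = bin_counts[:, :vocab_size]
    return bin_counts, bin_counts > 0

vocab_size, batch_size = 6, 1
logits = torch.tensor([[2.0, -1.0, 0.5, 0.0, 1.0, -0.5]])
past_output_tokens = torch.tensor([[2, 2, 4]])             # tokens generated so far
prompt_mask = torch.zeros(batch_size, vocab_size, dtype=torch.bool)
prompt_mask[0, 0] = True                                   # token 0 appeared in the prompt

bin_counts, output_mask = get_bin_counts_and_mask(past_output_tokens, vocab_size, batch_size)

# Repetition penalty touches only tokens seen in the prompt or the output so far.
rep = torch.full((batch_size, vocab_size), 1.2)
rep[~(prompt_mask | output_mask)] = 1.0
logits = torch.where(logits > 0, logits / rep, logits * rep)

# Frequency/presence penalties depend only on the generated output.
logits -= 0.1 * bin_counts
logits -= 0.2 * output_mask
print(logits)
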
@@ -447,7 +487,7 @@ class SamplingOutput:
 
 def sample(
     logits: torch.Tensor,
-    sampling_metadata: SamplingState,
+    sampling_state: SamplingState,
     check_safety: bool = False,
 ) -> SamplingOutput:
     def _is_safe_to_sample(prob_like):
@@ -457,7 +497,7 @@ def _is_safe_to_sample(prob_like):
         )
 
     res_greedy, res_random = None, None
-    sampling_tensors = sampling_metadata.sampling_tensors
+    sampling_tensors = sampling_state.sampling_tensors
 
     batch_size = logits.shape[0]
     mask_greedy_t, mask_random_t = (
@@ -466,13 +506,13 @@ def _is_safe_to_sample(prob_like):
     )
 
     next_tokens = np.empty((batch_size,), dtype=np.int64)
-    if sampling_metadata.has_greedy:
+    if sampling_state.has_greedy:
         res_greedy = torch.argmax(logits[mask_greedy_t], -1)
         np_mask_greedy = mask_greedy_t.cpu().numpy()
         next_tokens[np_mask_greedy] = res_greedy.cpu().numpy()
 
     probs_random = None
-    if sampling_metadata.has_random:
+    if sampling_state.has_random:
         probs_random = torch.softmax(logits[mask_random_t], dim=-1)
         if check_safety and not _is_safe_to_sample(probs_random):
             return None
@@ -481,9 +521,9 @@ def _is_safe_to_sample(prob_like):
         next_tokens[np_mask_random] = res_random.cpu().numpy()
 
     logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * batch_size
-    if sampling_metadata.has_logprob:
+    if sampling_state.has_logprob:
         # If everything is random sampling, save one extra softmax
-        if not sampling_metadata.has_greedy:
+        if not sampling_state.has_greedy:
             assert probs_random is not None
             logprobs = torch.log(probs_random)
         else:
@@ -494,13 +534,13 @@ def _is_safe_to_sample(prob_like):
         all_top_logprobs, all_top_tokens = torch.topk(
             extended_logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True
         )
-        mask = sampling_metadata.sampling_tensors.mask_top_logprob
+        mask = sampling_state.sampling_tensors.mask_top_logprob
         top_tokens = all_top_tokens[mask]
         top_logprobs = all_top_logprobs[mask]
-        for idx, batch_idx in enumerate(sampling_metadata.logprob_batch_indices):
+        for idx, batch_idx in enumerate(sampling_state.logprob_batch_indices):
             next_token = next_tokens[batch_idx]
-            assert sampling_metadata.sampling_params[batch_idx].logprobs
-            top_k = sampling_metadata.sampling_params[batch_idx].top_logprobs
+            assert sampling_state.sampling_params[batch_idx].logprobs
+            top_k = sampling_state.sampling_params[batch_idx].top_logprobs
             logprob_infos[batch_idx] = RawLogprobsInfo(
                 current_token_id=next_token,
                 current_logprob=logprobs[batch_idx][next_token],
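
For orientation, the masked greedy/random dispatch that `sample` performs above reduces to the pattern below. This is a simplified sketch with made-up masks and shapes, and `torch.multinomial` stands in for however `res_random` is actually drawn in the full implementation:

import numpy as np
import torch

batch_size, vocab_size = 4, 8
logits = torch.randn(batch_size, vocab_size)
mask_greedy_t = torch.tensor([True, False, True, False])   # rows using greedy decoding
mask_random_t = ~mask_greedy_t                              # remaining rows sample

next_tokens = np.empty((batch_size,), dtype=np.int64)

# Greedy rows: argmax over the (already adjusted) logits.
res_greedy = torch.argmax(logits[mask_greedy_t], -1)
next_tokens[mask_greedy_t.numpy()] = res_greedy.numpy()

# Random rows: softmax, then one draw per row.
probs_random = torch.softmax(logits[mask_random_t], dim=-1)
res_random = torch.multinomial(probs_random, 1, True)[:, 0]
next_tokens[mask_random_t.numpy()] = res_random.numpy()
print(next_tokens)
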