Skip to content

Commit

Permalink
update repetition penalties
Browse files Browse the repository at this point in the history
  • Loading branch information
Valery Chernov committed Feb 12, 2024
1 parent f64eb02 commit d1cc27a
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions serve/mlc_serve/model/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,9 +415,11 @@ def adjust_logits(
# in the right order.
if apply_penalty:
repetition_penalties_t = repetition_penalties_t[:, None].repeat(1, vocab_size)
logits = torch.where(
logits > 0, logits / repetition_penalties_t, logits * repetition_penalties_t
)
# RepetitionPenaltyLogitsProcessor approach from HF TGI API is used
# https://github.com/huggingface/transformers/blob/de11e654c962d5b23eb53a4387cd637b01987491/src/transformers/generation/logits_process.py#L332C1-L339C22
# where score is logits
# https://github.com/huggingface/transformers/blob/de11e654c962d5b23eb53a4387cd637b01987491/src/transformers/generation/logits_process.py#L76C1-L78C92
logits = logits / repetition_penalties_t
bin_counts = torch.zeros(
(batch_size, vocab_size + 1), dtype=torch.long, device=logits.device
)
Expand Down

0 comments on commit d1cc27a

Please sign in to comment.