Skip to content

Commit

Permalink
[Fix] Repetition penalty calculation fix (#2937)
Browse files Browse the repository at this point in the history
Repetition penalty compilation fix.
  • Loading branch information
shtinsa authored Sep 27, 2024
1 parent 268d52a commit fe79e9a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion python/mlc_llm/compiler_pass/attach_logit_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _apply_penalty_inplace( # pylint: disable=too-many-arguments,too-many-local
penalties[pos2seq_id[vp], 0] + token_cnt[vp] * penalties[pos2seq_id[vp], 1]
)
logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else(
logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] > 0,
logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] < 0,
logits[seq_ids[pos2seq_id[vp]], token_ids[vp]]
* penalties[pos2seq_id[vp], 2],
logits[seq_ids[pos2seq_id[vp]], token_ids[vp]]
Expand Down

0 comments on commit fe79e9a

Please sign in to comment.