Conversation

@baonudesifeizhai (Contributor) commented on Oct 25, 2025

Purpose

Fix issue #27486: Prevent duplicate BOS tokens when users manually apply chat templates before calling generate().

Problem

When users manually apply chat templates using tokenizer.apply_chat_template() and then call llm.generate() with the templated string, vLLM was incorrectly adding an additional BOS token during tokenization. This resulted in duplicate BOS tokens at the start of the prompt:

# User code
from transformers import AutoTokenizer
from vllm import LLM

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
messages = [{"role": "user", "content": "Hello!"}]  # example conversation
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
outputs = llm.generate([prompt])

# Expected prompt token IDs: [128000, 128006, ...]          (1 BOS token)
# Actual prompt token IDs:   [128000, 128000, 128006, ...]  (2 BOS tokens - WRONG!)

This happened because (a reproduction sketch follows this list):

  1. apply_chat_template() generates text containing BOS token markers (e.g., <|begin_of_text|>)
  2. When vLLM tokenizes this string, it adds special tokens by default
  3. The tokenizer adds another BOS token, resulting in duplication
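The duplication in step 3 can be reproduced with the tokenizer alone. Below is a minimal sketch, assuming the same Llama 3.2 tokenizer and a trivial messages list; the token IDs in the comments are illustrative:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
messages = [{"role": "user", "content": "Hello!"}]
templated = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# The templated string already starts with <|begin_of_text|>, so tokenizing it
# with special tokens enabled prepends a second BOS.
ids_default = tokenizer(templated, add_special_tokens=True).input_ids
ids_no_special = tokenizer(templated, add_special_tokens=False).input_ids
print(ids_default[:2])     # e.g. [128000, 128000] - BOS duplicated
print(ids_no_special[:2])  # e.g. [128000, 128006] - single BOS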

Solution

Added a heuristic check in InputPreprocessor._tokenize_prompt() that detects whether a prompt already contains special tokens before adding them (a sketch of the heuristic follows the list of changes):

Changes:

  1. New method _should_add_special_tokens() (vllm/inputs/preprocess.py)

    • Detects BOS token text at the start of prompts (e.g., <|begin_of_text|>)
    • Detects common chat template markers in the first 100 characters:
      • <|start_header_id|> (Llama 3.x)
      • <|im_start|> (ChatML format)
      • <s> (older models)
    • Returns False if special tokens are detected, True otherwise
  2. Modified _tokenize_prompt() (vllm/inputs/preprocess.py)

    • Calls _should_add_special_tokens() before tokenizing
    • Sets add_special_tokens=False when chat template markers are detected
    • Preserves backward compatibility - only affects prompts with detected markers
  3. New test case (tests/entrypoints/llm/test_chat.py)

    • test_llm_generate_with_chat_template_no_double_bos()
    • Verifies that manually applying chat templates doesn't cause duplicate BOS
    • Compares vLLM output with expected transformers tokenization
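A minimal sketch of the heuristic described in items 1 and 2, written as an assumption about the diff; the exact signature and BOS handling in the real code may differ:

from typing import Optional

def _should_add_special_tokens(prompt: str, bos_token: Optional[str]) -> bool:
    """Return False when the prompt already carries a BOS or chat-template marker."""
    # BOS token text at the very start of the prompt (e.g. <|begin_of_text|>).
    if bos_token and prompt.startswith(bos_token):
        return False

    # Common chat-template markers, checked only in the first 100 characters.
    chat_markers = (
        "<|start_header_id|>",  # Llama 3.x
        "<|im_start|>",         # ChatML format
        "<s>",                  # common BOS in older models
    )
    head = prompt[:100]
    return not any(marker in head for marker in chat_markers)

When this returns False, _tokenize_prompt() passes add_special_tokens=False to the tokenizer, which is what prevents the second BOS.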

Test Plan

pytest tests/entrypoints/llm/test_chat.py::test_llm_generate_with_chat_template_no_double_bos -v -s

Result: pass
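For reference, an illustrative sketch of what the new test asserts; the real test in tests/entrypoints/llm/test_chat.py may use different fixtures and model handling:

from transformers import AutoTokenizer
from vllm import LLM


def test_llm_generate_with_chat_template_no_double_bos():
    model = "meta-llama/Llama-3.2-1B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model)
    llm = LLM(model=model)

    messages = [{"role": "user", "content": "Hello!"}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    outputs = llm.generate([prompt])
    prompt_token_ids = outputs[0].prompt_token_ids

    # The templated text already starts with BOS, so it must appear exactly once.
    bos = tokenizer.bos_token_id
    assert prompt_token_ids[0] == bos
    assert prompt_token_ids[1] != bos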



@baonudesifeizhai changed the title from "Fix issue 27486 double bos token" to "Fix issue #27486 double bos token" on Oct 25, 2025

@gemini-code-assist bot left a comment


Code Review

This pull request aims to fix an issue with duplicate BOS tokens when a chat template is manually applied. The approach is to heuristically detect if special tokens are already present in the prompt. While the fix works for the intended cases, the implementation introduces a critical bug due to an overly broad heuristic for detecting the <s> token. My review includes a critical comment with a suggested code change to fix this bug by making the detection logic more robust and less prone to false positives.


@chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 243 to 251
chat_markers = [
    "<|start_header_id|>",  # Llama 3.x
    "<|im_start|>",  # ChatML format
    "<s>",  # Common BOS in older models
]

for marker in chat_markers:
    if marker in prompt[:100]:  # Check first 100 chars
        logger.debug(


P1: Avoid treating a literal "<s>" in user text as a chat-template marker

The new _should_add_special_tokens uses if marker in prompt[:100] with chat_markers containing <s>. Any plain prompt that merely mentions the literal string "<s>" (for example, a question about the HTML <s> tag) will now be classified as already containing special tokens, and add_special_tokens will be forced to False. That removes the BOS/EOS tokens that were added before this change and can noticeably degrade generation quality for legitimate prompts that simply contain <s> in their content. The detection should be restricted to prefix checks or token-boundary matches instead of a substring search to avoid this regression.
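One way to implement this suggestion, sketched here as an assumption rather than code from the PR, is to restrict the check to the start of the prompt:

def _prompt_starts_with_marker(prompt: str, markers: tuple[str, ...]) -> bool:
    # Treat a marker as "already present" only when it opens the prompt
    # (optionally after leading whitespace), so prompts that merely mention
    # "<s>" somewhere in their body keep their normal special-token handling.
    stripped = prompt.lstrip()
    return any(stripped.startswith(marker) for marker in markers)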


@DarkLight1337 mentioned this pull request on Oct 26, 2025

@DarkLight1337 (Member) left a comment


cc @njhill I remember that a similar issue was filed a while back?

@DarkLight1337 (Member) commented

See #9519
