tests/entrypoints/llm/test_chat.py (58 additions, 0 deletions)
@@ -116,6 +116,64 @@ def test_llm_chat_tokenization_no_double_bos(text_llm):
assert prompt_token_ids[1] != bos_token, "Double BOS"


def test_llm_generate_with_chat_template_no_double_bos(text_llm):
"""
    Test for issue #27486: when the chat template is applied manually via
    apply_chat_template and the rendered prompt is passed to generate(),
    vLLM should not duplicate the BOS token.

Note: This test reuses the text_llm fixture's tokenizer to avoid
GPU resource conflicts when running the full test suite.
"""
from vllm import SamplingParams

# Reuse the tokenizer from the existing LLM instance
tokenizer = text_llm.get_tokenizer()

messages = [{"role": "user", "content": "Hello, how are you?"}]

# Apply chat template manually (as users might do)
prompt_with_template = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

# Get expected token IDs from transformers
expected_token_ids = tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True
)

# Generate using vLLM with the templated string
outputs = text_llm.generate(
[prompt_with_template],
sampling_params=SamplingParams(temperature=0.0, max_tokens=1),
)

assert len(outputs) == 1
prompt_token_ids = outputs[0].prompt_token_ids
assert prompt_token_ids is not None

# Check that vLLM produces the same token IDs as transformers
assert len(prompt_token_ids) == len(expected_token_ids), (
f"Length mismatch: vLLM has {len(prompt_token_ids)} tokens, "
f"expected {len(expected_token_ids)}"
)

# Verify no duplicate BOS at the start
bos_token = tokenizer.bos_token_id
assert prompt_token_ids[0] == bos_token, "First token should be BOS"
assert prompt_token_ids[1] != bos_token, (
"Second token should not be BOS (no duplication)"
)

# Verify exact match
if prompt_token_ids != expected_token_ids:
mismatch_idx = next(
i
for i, (a, b) in enumerate(zip(prompt_token_ids, expected_token_ids))
if a != b
)
raise AssertionError(f"Token mismatch at index {mismatch_idx}")


@pytest.fixture(scope="function")
def thinking_llm():
# pytest caches the fixture so we use weakref.proxy to
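
For context, a minimal end-to-end reproduction of the double-BOS scenario this test exercises might look like the sketch below. This is illustrative only and not part of the diff; the model name is an assumption, and the exact token IDs depend on the tokenizer in use.

# Hedged reproduction sketch for issue #27486 (illustrative, not part of the PR).
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct")  # assumed model
tokenizer = llm.get_tokenizer()

messages = [{"role": "user", "content": "Hello, how are you?"}]
# Render the chat template manually; for Llama-style templates the resulting
# string already starts with the BOS token text.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

outputs = llm.generate([prompt], SamplingParams(temperature=0.0, max_tokens=1))
ids = outputs[0].prompt_token_ids
# Before this change, ids[0] and ids[1] could both equal tokenizer.bos_token_id
# (a duplicated BOS); with the change, only a single leading BOS should remain.
print(ids[:3], tokenizer.bos_token_id)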
vllm/inputs/preprocess.py (58 additions, 0 deletions)
@@ -209,6 +209,57 @@ def _get_tokenization_kw(

return kwargs

def _should_add_special_tokens(self, prompt: str) -> bool:
"""
Determine whether to add special tokens when tokenizing a prompt.

Returns False if the prompt appears to already contain special tokens
(e.g., from chat template application), True otherwise.

This helps avoid duplicating BOS/EOS tokens when users manually apply
chat templates before calling generate().

See: https://github.com/vllm-project/vllm/issues/27486
"""
if self.tokenizer is None:
return True

# Check if tokenizer has a BOS token
bos_token = getattr(self.tokenizer, "bos_token", None)
if not bos_token:
return True

# If prompt starts with BOS token text, don't add special tokens
# This handles cases like Llama's "<|begin_of_text|>" and "<s>"
# Use lstrip() to handle prompts with leading whitespace
if prompt.lstrip().startswith(bos_token):
logger.debug(
"Detected BOS token at the start of prompt. "
"Setting add_special_tokens=False to avoid duplication."
)
return False

# Check for common chat template markers that indicate special tokens
# are already present (e.g., "<|start_header_id|>", "<|im_start|>").
# Note: "<s>" is intentionally not included here as it is too generic
# and can cause false positives (e.g., in "<script>" tags). The check
# above handles "<s>" when it's a BOS token via startswith().
chat_markers = [
"<|start_header_id|>", # Llama 3.x
"<|im_start|>", # ChatML format
]

for marker in chat_markers:
if marker in prompt[:100]: # Check first 100 chars
logger.debug(
"Detected chat template marker '%s' in prompt. "
"Setting add_special_tokens=False to avoid duplication.",
marker,
)
return False

return True

def _tokenize_prompt(
self,
prompt: str,
@@ -221,6 +272,13 @@ def _tokenize_prompt(
tokenizer = self.get_tokenizer()
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)

        # Decide whether special tokens should be added, unless the caller has
        # set add_special_tokens explicitly. This prevents a duplicated BOS
        # token when the prompt was rendered with a chat template.
if "add_special_tokens" not in tokenization_kwargs:
tokenization_kwargs["add_special_tokens"] = self._should_add_special_tokens(
prompt
)

encoder_config = self.model_config.encoder_config

if encoder_config and encoder_config.get("do_lower_case", False):
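
To make the heuristic above easier to follow in isolation, here is a standalone sketch of the same detection logic written against a plain Hugging Face tokenizer. It is an illustration under assumptions (the tokenizer name and the helper name are hypothetical), not the code that ships in this PR.

# Standalone sketch of the BOS/chat-marker detection heuristic (illustrative).
from transformers import AutoTokenizer

CHAT_MARKERS = ("<|start_header_id|>", "<|im_start|>")  # Llama 3.x, ChatML

def should_add_special_tokens(tokenizer, prompt: str) -> bool:
    """Return False if the prompt already appears to carry special tokens."""
    bos_token = getattr(tokenizer, "bos_token", None)
    if not bos_token:
        return True
    # Prompt already starts with the BOS token text (e.g. after apply_chat_template).
    if prompt.lstrip().startswith(bos_token):
        return False
    # Only inspect the head of the prompt, mirroring the 100-char window above.
    head = prompt[:100]
    return not any(marker in head for marker in CHAT_MARKERS)

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")  # assumed
templated = tok.apply_chat_template(
    [{"role": "user", "content": "Hi"}], tokenize=False, add_generation_prompt=True
)
ids = tok(templated, add_special_tokens=should_add_special_tokens(tok, templated)).input_ids
assert ids.count(tok.bos_token_id) == 1  # single BOS at the start, no duplication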