Skip to content

Conversation

@tingtingtang1992
Copy link
Contributor

@tingtingtang1992 tingtingtang1992 commented Oct 26, 2025

Purpose

When sending mm inference req with prompt below to gemma3 (gemma3n as well), user didn't add placeholder string <start_of_image> themself.

{'role': 'system', 'content': [{'type': 'text', 'text': 'you are a helpful artist'}]}, {'role': 'user', 'content': [{'type': 'text', 'text': 'Describe the images in detail.'}
after parse_chat_messages, the prompt becomes:

{'role': 'system', 'content': 'you are a helpful artist'}, {'role': 'user', 'content': '<start_of_image>\nDescribe the images in detail.'}
note the single \n with token id 107 is inserted

with vllm tries to replace mm placeholder tokens
[108, 255999, 262144 ... 262144, 256000, 108] with 255999, this is determined by logic here

then the token becomes [...262144, 262144, 256000, 108, 107...], according the logic here, token 108, 107 will be merged into 109

while in _find_mm_placeholders 109(\n\n\n) was replaced as 107(\n), 108(\n\n), so this will cause mm placeholder couldn't be found.

There might be better way to fix this, feel free to suggest, I can apply to gemma3n as well if looks good.

Test Plan

Runs locally, before this change, the above inference failed to find mm placeholder, after this change, request can go through

Test Result


Essential Elements of an Effective PR Description Checklist
  • fixing mm placeholder replacement issue with gemma3
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

When sending mm inference req with prompt below to gemma3 (gemma3n as well), user didn't add placeholder string <start_of_image> themself.

{'role': 'system', 'content': [{'type': 'text', 'text': 'you are a helpful artist'}]}, {'role': 'user', 'content': [{'type': 'text', 'text': 'Describe the images in detail.'}

after parse_chat_messages, the prompt becomes:

{'role': 'system', 'content': 'you are a helpful artist'}, {'role': 'user', 'content': '<start_of_image>\nDescribe the images in detail.'}, note the single \n with token id 107 is inserted

with vllm tries to replace mm placeholder tokens 
[108, 255999, 262144 ... 262144, 256000, 108] with 255999, this is determined by logic here https://github.com/huggingface/transformers/blob/v4.57.1/src/transformers/models/gemma3/processing_gemma3.py#L71C106-L71C111 

then the token becomes [...262144, 262144, 256000, 108, 107...], according the logic here https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py#L376, token 108, 107 will be merged into 109

while in _find_mm_placeholders https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py#L406 109 was replaced as 107, 108, so this will cause mm placeholder couldn't be found.

There might be better way to fix this, feel free to suggest, I can apply to gemma3n as well if looks good.

Signed-off-by: tingtingtang1992 <[email protected]>
@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors.

You ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request aims to fix an issue with multi-modal placeholder replacement for the gemma3 model. The issue occurs when newline tokens are merged ambiguously, causing placeholder detection to fail. The proposed change corrects the decomposition of a merged newline token for a specific case. However, my review found that this fix is incomplete and introduces a bug in another scenario. The root cause is an ambiguous merge operation that is not perfectly invertible. I've provided a critical comment with a suggestion for a more robust fix that addresses the ambiguity at its source, ensuring both cases are handled correctly.

def get_repl_toks(tok: int) -> list[int]:
if tok == newline_3:
return [newline_1, newline_2]
return [newline_2, newline_1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This change correctly handles the case where [newline_2, newline_1] is merged into newline_3. However, it breaks the case where [newline_1, newline_2] is merged, as the decomposition is no longer its inverse.

The root cause is the ambiguous merging in _apply_token_matches (lines 373-382) where both [newline_1, newline_2] and [newline_2, newline_1] are merged into newline_3. The decomposition here in _find_mm_placeholders can only be the inverse of one of these, making the transformation lossy and causing one of the cases to fail.

A more robust fix would be to make the merging unambiguous. For example, by only merging [newline_2, newline_1] into newline_3 in _apply_token_matches and removing the merge for [newline_1, newline_2].

Specifically, you could remove lines 373-377 in _apply_token_matches:

        token_ids = replace_token_matches(
            token_ids,
            [newline_1, newline_2],
            [newline_3],
        )

With that change, your change here to decompose newline_3 into [newline_2, newline_1] would be correct and not break other cases. Without changing _apply_token_matches, this PR trades one bug for another.

Copy link
Member

@DarkLight1337 DarkLight1337 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, let's see if the tests pass

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) October 27, 2025 02:42
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 27, 2025
@DarkLight1337 DarkLight1337 merged commit 23ad820 into vllm-project:main Oct 27, 2025
56 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants