Skip to content

Commit d847227

Browse files
committed
precommit
Signed-off-by: 0xrushi <[email protected]>
1 parent a47979b commit d847227

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

tests/test_inputs.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,40 @@ def test_preprocessor_always_mm_code_path(model_id, prompt):
135135

136136
processed_inputs = input_preprocessor.preprocess(prompt)
137137
assert sep_token_id in processed_inputs["prompt_token_ids"]
138+
139+
140+
def _get_bos_prefixed_prompt_or_skip(tokenizer):
141+
bos_token = getattr(tokenizer, "bos_token", None)
142+
if not bos_token or not isinstance(bos_token, str):
143+
pytest.skip("Tokenizer has no string bos_token to test BOS handling.")
144+
return f"{bos_token} Hello world"
145+
146+
147+
@pytest.mark.parametrize(
    "explicit_add_special",
    [True, None],
)
def test_double_bos_token(monkeypatch, explicit_add_special):
    """A BOS-prefixed prompt must not pick up a second BOS during tokenization.

    When the caller explicitly passes ``add_special_tokens=True`` it is
    forwarded to the tokenizer unchanged; with no explicit setting, the
    preprocessor is expected to pass ``add_special_tokens=False`` because
    the prompt already starts with the BOS token.
    """
    model_config = ModelConfig(model="facebook/opt-125m")
    input_preprocessor = InputPreprocessor(model_config)

    tokenizer = input_preprocessor.get_tokenizer()
    prompt = _get_bos_prefixed_prompt_or_skip(tokenizer)

    # Intercept the encode call so the kwargs it receives can be inspected.
    seen: dict[str, object] = {}

    def fake_encode(text, **kwargs):
        seen["kwargs"] = dict(kwargs)
        # Dummy token ids; the actual values are irrelevant to this test.
        return [101, 102, 103]

    monkeypatch.setattr(tokenizer, "encode", fake_encode, raising=True)

    if explicit_add_special is True:
        _ = input_preprocessor._tokenize_prompt(
            prompt, tokenization_kwargs={"add_special_tokens": True}
        )
        expected = True
    else:
        _ = input_preprocessor._tokenize_prompt(prompt)
        expected = False

    assert seen["kwargs"].get("add_special_tokens") is expected

vllm/inputs/preprocess.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,6 @@ def _tokenize_prompt(
221221
tokenizer = self.get_tokenizer()
222222
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
223223

224-
225224
bos_token_text = getattr(tokenizer, "bos_token", None)
226225
if bos_token_text and isinstance(bos_token_text, str):
227226
if prompt.lstrip().startswith(bos_token_text):

0 commit comments

Comments
 (0)