
Commit c92c761

WorldExplored, vnadathur, gemini-code-assist[bot], and noooop authored and committed
[Frontend] Added chat-style multimodal support to /classify. (vllm-project#27516)
Signed-off-by: WorldExplored <[email protected]>
Signed-off-by: Srreyansh Sethi <[email protected]>
Signed-off-by: vnadathur <[email protected]>
Signed-off-by: wang.yuqi <[email protected]>
Co-authored-by: vnadathur <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: vnadathur <[email protected]>
Co-authored-by: wang.yuqi <[email protected]>
Co-authored-by: wang.yuqi <[email protected]>
Signed-off-by: George D. Torres <[email protected]>
1 parent f42888b commit c92c761

File tree

6 files changed (+318 -27 lines changed)

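This commit lets /classify accept chat-style "messages" (including multimodal content parts) alongside the existing plain "input" field. For orientation, a client-side request could look roughly like the following sketch; it assumes a server built from this commit is already running on localhost:8000 with a classification-capable vision-language model, and the video URL is a placeholder (the new tests below exercise the real flow):

import requests

# Chat-style /classify request: "messages" replaces the plain "input" field
# and may carry multimodal content parts such as a video URL.
payload = {
    "model": "muziyongshixin/Qwen2.5-VL-7B-for-VideoCls",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Please classify this video."},
                {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}},
            ],
        }
    ],
}

resp = requests.post("http://localhost:8000/classify", json=payload)
resp.raise_for_status()
# Each result carries per-class probabilities, as asserted in the tests below.
print(resp.json()["data"][0]["probs"])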

tests/entrypoints/pooling/openai/test_classification.py

Lines changed: 10 additions & 0 deletions
@@ -46,6 +46,16 @@ def test_single_input_classification(server: RemoteOpenAIServer, model_name: str
     assert hasattr(output.data[0], "probs")
 
 
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+def test_add_special_tokens_false(server: RemoteOpenAIServer, model_name: str):
+    response = requests.post(
+        server.url_for("classify"),
+        json={"model": model_name, "input": "hello", "add_special_tokens": False},
+    )
+    response.raise_for_status()
+    ClassificationResponse.model_validate(response.json())
+
+
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 def test_multiple_inputs_classification(server: RemoteOpenAIServer, model_name: str):
     input_texts = [
Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import json
+
+import pytest
+import requests
+
+from tests.utils import RemoteOpenAIServer
+from vllm.entrypoints.openai.protocol import ClassificationResponse
+
+VLM_MODEL_NAME = "muziyongshixin/Qwen2.5-VL-7B-for-VideoCls"
+MAXIMUM_VIDEOS = 1
+TEST_VIDEO_URL = "https://www.bogotobogo.com/python/OpenCV_Python/images/mean_shift_tracking/slow_traffic_small.mp4"
+
+HF_OVERRIDES = {
+    "text_config": {
+        "architectures": ["Qwen2_5_VLForSequenceClassification"],
+    },
+}
+
+
+@pytest.fixture(scope="module")
+def server_vlm_classify():
+    args = [
+        "--runner",
+        "pooling",
+        "--max-model-len",
+        "5000",
+        "--enforce-eager",
+        "--limit-mm-per-prompt",
+        json.dumps({"video": MAXIMUM_VIDEOS}),
+    ]
+
+    with RemoteOpenAIServer(
+        VLM_MODEL_NAME, args, override_hf_configs=HF_OVERRIDES
+    ) as remote_server:
+        yield remote_server
+
+
+@pytest.mark.parametrize("model_name", [VLM_MODEL_NAME])
+def test_classify_accepts_chat_text_only(
+    server_vlm_classify: RemoteOpenAIServer, model_name: str
+) -> None:
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Please classify this text request."},
+            ],
+        }
+    ]
+
+    response = requests.post(
+        server_vlm_classify.url_for("classify"),
+        json={"model": model_name, "messages": messages},
+    )
+    response.raise_for_status()
+
+    output = ClassificationResponse.model_validate(response.json())
+
+    assert output.object == "list"
+    assert output.model == model_name
+    assert len(output.data) == 1
+    assert len(output.data[0].probs) == 2
+    assert output.usage.prompt_tokens == 22
+
+
+@pytest.mark.parametrize("model_name", [VLM_MODEL_NAME])
+def test_classify_accepts_chat_video_url(
+    server_vlm_classify: RemoteOpenAIServer, model_name: str
+) -> None:
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Please classify this video."},
+                {"type": "video_url", "video_url": {"url": TEST_VIDEO_URL}},
+            ],
+        }
+    ]
+
+    response = requests.post(
+        server_vlm_classify.url_for("classify"),
+        json={"model": model_name, "messages": messages},
+    )
+    response.raise_for_status()
+
+    output = ClassificationResponse.model_validate(response.json())
+
+    assert output.object == "list"
+    assert output.model == model_name
+    assert len(output.data) == 1
+    assert len(output.data[0].probs) == 2
+    assert output.usage.prompt_tokens == 4807

vllm/entrypoints/openai/api_server.py

Lines changed: 3 additions & 0 deletions
@@ -1784,6 +1784,9 @@ async def init_app_state(
                 engine_client,
                 state.openai_serving_models,
                 request_logger=request_logger,
+                chat_template=resolved_chat_template,
+                chat_template_content_format=args.chat_template_content_format,
+                trust_request_chat_template=args.trust_request_chat_template,
                 log_error_stack=args.log_error_stack,
             )
             if "classify" in supported_tasks

vllm/entrypoints/openai/protocol.py

Lines changed: 113 additions & 3 deletions
@@ -2000,10 +2000,10 @@ class ScoreResponse(OpenAIBaseModel):
     usage: UsageInfo
 
 
-class ClassificationRequest(OpenAIBaseModel):
+class ClassificationCompletionRequest(OpenAIBaseModel):
     model: str | None = None
     input: list[str] | str
-    truncate_prompt_tokens: int | None = None
+    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
     user: str | None = None
 
     # --8<-- [start:classification-extra-params]
@@ -2015,7 +2015,21 @@ class ClassificationRequest(OpenAIBaseModel):
             "if the served model does not use priority scheduling."
         ),
     )
-
+    add_special_tokens: bool = Field(
+        default=True,
+        description=(
+            "If true (the default), special tokens (e.g. BOS) will be added to "
+            "the prompt."
+        ),
+    )
+    request_id: str = Field(
+        default_factory=lambda: f"{random_uuid()}",
+        description=(
+            "The request_id related to this request. If the caller does "
+            "not set it, a random_uuid will be generated. This id is used "
+            "through out the inference process and return in response."
+        ),
+    )
     softmax: bool | None = Field(
         default=None,
         description="softmax will be deprecated, please use use_activation instead.",
@@ -2040,6 +2054,102 @@ def to_pooling_params(self):
         )
 
 
+class ClassificationChatRequest(OpenAIBaseModel):
+    model: str | None = None
+    messages: list[ChatCompletionMessageParam]
+    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
+    user: str | None = None
+
+    # --8<-- [start:chat-classification-extra-params]
+    add_generation_prompt: bool = Field(
+        default=False,
+        description=(
+            "If true, the generation prompt will be added to the chat template. "
+            "This is a parameter used by chat template in tokenizer config of the "
+            "model."
+        ),
+    )
+
+    add_special_tokens: bool = Field(
+        default=False,
+        description=(
+            "If true, special tokens (e.g. BOS) will be added to the prompt "
+            "on top of what is added by the chat template. "
+            "For most models, the chat template takes care of adding the "
+            "special tokens so this should be set to false (as is the "
+            "default)."
+        ),
+    )
+
+    chat_template: str | None = Field(
+        default=None,
+        description=(
+            "A Jinja template to use for this conversion. "
+            "As of transformers v4.44, default chat template is no longer "
+            "allowed, so you must provide a chat template if the tokenizer "
+            "does not define one."
+        ),
+    )
+
+    chat_template_kwargs: dict[str, Any] | None = Field(
+        default=None,
+        description=(
+            "Additional keyword args to pass to the template renderer. "
+            "Will be accessible by the chat template."
+        ),
+    )
+
+    mm_processor_kwargs: dict[str, Any] | None = Field(
+        default=None,
+        description=("Additional kwargs to pass to the HF processor."),
+    )
+
+    priority: int = Field(
+        default=0,
+        description=(
+            "The priority of the request (lower means earlier handling; "
+            "default: 0). Any priority other than 0 will raise an error "
+            "if the served model does not use priority scheduling."
+        ),
+    )
+
+    request_id: str = Field(
+        default_factory=lambda: f"{random_uuid()}",
+        description=(
+            "The request_id related to this request. If the caller does "
+            "not set it, a random_uuid will be generated. This id is used "
+            "through out the inference process and return in response."
+        ),
+    )
+    softmax: bool | None = Field(
+        default=None,
+        description="softmax will be deprecated, please use use_activation instead.",
+    )
+
+    activation: bool | None = Field(
+        default=None,
+        description="activation will be deprecated, please use use_activation instead.",
+    )
+
+    use_activation: bool | None = Field(
+        default=None,
+        description="Whether to use activation for classification outputs. "
+        "Default is True.",
+    )
+    # --8<-- [end:chat-classification-extra-params]
+
+    def to_pooling_params(self):
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            use_activation=get_use_activation(self),
+        )
+
+
+ClassificationRequest: TypeAlias = (
+    ClassificationCompletionRequest | ClassificationChatRequest
+)
+
+
 class ClassificationData(OpenAIBaseModel):
     index: int
     label: str | None

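Since ClassificationRequest is now a union, downstream code can no longer assume every request carries an "input" field. A minimal sketch (not the actual vLLM handler) of branching on the two request shapes introduced above:

from vllm.entrypoints.openai.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
)


def describe(request) -> str:
    # ClassificationRequest is a TypeAlias, so callers branch on the
    # concrete members of the union.
    if isinstance(request, ClassificationChatRequest):
        return f"chat-style request with {len(request.messages)} message(s)"
    assert isinstance(request, ClassificationCompletionRequest)
    texts = [request.input] if isinstance(request.input, str) else request.input
    return f"completion-style request with {len(texts)} input string(s)"


# Both variants expose the same to_pooling_params() hook used by the server.
params = ClassificationCompletionRequest(input="hello").to_pooling_params()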