Merged
Changes from all commits (27 commits)
9ac8d5f  modified API servers (WorldExplored, Oct 25, 2025)
1b61ed7  refactors to simplify code (WorldExplored, Oct 25, 2025)
8e6762c  fixed conditional logic (WorldExplored, Oct 25, 2025)
58e2c45  Update vllm/entrypoints/openai/serving_classification.py (WorldExplored, Oct 25, 2025)
bfd1ea9  fixed pre-commit (WorldExplored, Oct 26, 2025)
301c491  addressed reviewer concerns (WorldExplored, Oct 26, 2025)
0c4d457  Merge branch 'main' into APIfix (WorldExplored, Oct 26, 2025)
97344fb  Merge branch 'main' into APIfix (WorldExplored, Oct 28, 2025)
12e08e0  addressed reviewer comments (WorldExplored, Oct 28, 2025)
dcdd9bd  fix: restore and correctly position aiosignal dependency comment (WorldExplored, Oct 28, 2025)
7791d2b  addressed reviews (WorldExplored, Oct 28, 2025)
21c8331  Merge branch 'main' into APIfix (WorldExplored, Oct 29, 2025)
a5041b5  addressed reviewer comments (WorldExplored, Oct 29, 2025)
58aae95  Merge branch 'main' into APIfix (WorldExplored, Oct 29, 2025)
ffdf73d  Addressed Reviewer Feedback (WorldExplored, Oct 29, 2025)
6d72821  Merge branch 'main' into APIfix (vnadathur, Oct 31, 2025)
4bab920  Merge branch 'main' into APIfix (WorldExplored, Nov 1, 2025)
abc9661  fixed pre-commit (WorldExplored, Nov 1, 2025)
8dadd78  addressed reviewer concerns (WorldExplored, Nov 2, 2025)
51f45af  Merge branch 'main' into APIfix (noooop, Nov 10, 2025)
4175a0a  added conversion trick to test (vnadathur, Nov 11, 2025)
ebcfdf7  Merge branch 'main' into APIfix (noooop, Nov 13, 2025)
e7d8411  addressed ci check (WorldExplored, Nov 13, 2025)
2343460  pre commit fix (WorldExplored, Nov 14, 2025)
25d1c79  fix tests (noooop, Nov 14, 2025)
396b057  Merge branch 'main' into APIfix (noooop, Nov 14, 2025)
2b810d8  input cannot be None (noooop, Nov 14, 2025)
10 changes: 10 additions & 0 deletions tests/entrypoints/pooling/openai/test_classification.py
@@ -46,6 +46,16 @@ def test_single_input_classification(server: RemoteOpenAIServer, model_name: str
assert hasattr(output.data[0], "probs")


@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_add_special_tokens_false(server: RemoteOpenAIServer, model_name: str):
response = requests.post(
server.url_for("classify"),
json={"model": model_name, "input": "hello", "add_special_tokens": False},
)
response.raise_for_status()
ClassificationResponse.model_validate(response.json())


@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_multiple_inputs_classification(server: RemoteOpenAIServer, model_name: str):
input_texts = [
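For context, a minimal client-side sketch of the new add_special_tokens parameter on the /classify endpoint, mirroring the test above; BASE_URL and the model name are placeholders and assume a running vLLM server with a classification-capable pooling model.

# Sketch only: BASE_URL and the model name are placeholders, not part of this PR.
import requests

BASE_URL = "http://localhost:8000"

resp = requests.post(
    f"{BASE_URL}/classify",
    json={
        "model": "my-classification-model",
        "input": "hello",
        # New extra param: skip the special tokens (e.g. BOS) added by default.
        "add_special_tokens": False,
    },
)
resp.raise_for_status()
body = resp.json()
# Each entry in body["data"] carries an index, a label, and per-class probs.
print(body["data"][0]["label"], body["data"][0]["probs"])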
95 changes: 95 additions & 0 deletions tests/entrypoints/pooling/openai/test_vision_classification.py
@@ -0,0 +1,95 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json

import pytest
import requests

from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import ClassificationResponse

VLM_MODEL_NAME = "muziyongshixin/Qwen2.5-VL-7B-for-VideoCls"
MAXIMUM_VIDEOS = 1
TEST_VIDEO_URL = "https://www.bogotobogo.com/python/OpenCV_Python/images/mean_shift_tracking/slow_traffic_small.mp4"

HF_OVERRIDES = {
"text_config": {
"architectures": ["Qwen2_5_VLForSequenceClassification"],
},
}


@pytest.fixture(scope="module")
def server_vlm_classify():
args = [
"--runner",
"pooling",
"--max-model-len",
"5000",
"--enforce-eager",
"--limit-mm-per-prompt",
json.dumps({"video": MAXIMUM_VIDEOS}),
]

with RemoteOpenAIServer(
VLM_MODEL_NAME, args, override_hf_configs=HF_OVERRIDES
) as remote_server:
yield remote_server


@pytest.mark.parametrize("model_name", [VLM_MODEL_NAME])
def test_classify_accepts_chat_text_only(
server_vlm_classify: RemoteOpenAIServer, model_name: str
) -> None:
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Please classify this text request."},
],
}
]

response = requests.post(
server_vlm_classify.url_for("classify"),
json={"model": model_name, "messages": messages},
)
response.raise_for_status()

output = ClassificationResponse.model_validate(response.json())

assert output.object == "list"
assert output.model == model_name
assert len(output.data) == 1
assert len(output.data[0].probs) == 2
assert output.usage.prompt_tokens == 22


@pytest.mark.parametrize("model_name", [VLM_MODEL_NAME])
def test_classify_accepts_chat_video_url(
server_vlm_classify: RemoteOpenAIServer, model_name: str
) -> None:
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Please classify this video."},
{"type": "video_url", "video_url": {"url": TEST_VIDEO_URL}},
],
}
]

response = requests.post(
server_vlm_classify.url_for("classify"),
json={"model": model_name, "messages": messages},
)
response.raise_for_status()

output = ClassificationResponse.model_validate(response.json())

assert output.object == "list"
assert output.model == model_name
assert len(output.data) == 1
assert len(output.data[0].probs) == 2
assert output.usage.prompt_tokens == 4807
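For reference, a client-side sketch of the chat-style classification request exercised by the tests above; it assumes a server launched as in the server_vlm_classify fixture, and BASE_URL is a placeholder.

# Sketch only: assumes a vLLM server running muziyongshixin/Qwen2.5-VL-7B-for-VideoCls
# in pooling mode at BASE_URL, as configured by the fixture above.
import requests

BASE_URL = "http://localhost:8000"
VIDEO_URL = (
    "https://www.bogotobogo.com/python/OpenCV_Python/images/"
    "mean_shift_tracking/slow_traffic_small.mp4"
)

resp = requests.post(
    f"{BASE_URL}/classify",
    json={
        "model": "muziyongshixin/Qwen2.5-VL-7B-for-VideoCls",
        # /classify now also accepts OpenAI-style chat "messages", including
        # multimodal content parts such as video_url.
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Please classify this video."},
                    {"type": "video_url", "video_url": {"url": VIDEO_URL}},
                ],
            }
        ],
    },
)
resp.raise_for_status()
print(resp.json())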
3 changes: 3 additions & 0 deletions vllm/entrypoints/openai/api_server.py
@@ -1784,6 +1784,9 @@ async def init_app_state(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
log_error_stack=args.log_error_stack,
)
if "classify" in supported_tasks
116 changes: 113 additions & 3 deletions vllm/entrypoints/openai/protocol.py
@@ -2000,10 +2000,10 @@ class ScoreResponse(OpenAIBaseModel):
usage: UsageInfo


class ClassificationRequest(OpenAIBaseModel):
class ClassificationCompletionRequest(OpenAIBaseModel):
model: str | None = None
input: list[str] | str
truncate_prompt_tokens: int | None = None
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
user: str | None = None

# --8<-- [start:classification-extra-params]
@@ -2015,7 +2015,21 @@ class ClassificationRequest(OpenAIBaseModel):
"if the served model does not use priority scheduling."
),
)

add_special_tokens: bool = Field(
default=True,
description=(
"If true (the default), special tokens (e.g. BOS) will be added to "
"the prompt."
),
)
request_id: str = Field(
default_factory=lambda: f"{random_uuid()}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
softmax: bool | None = Field(
default=None,
description="softmax will be deprecated, please use use_activation instead.",
@@ -2040,6 +2054,102 @@ def to_pooling_params(self):
)


class ClassificationChatRequest(OpenAIBaseModel):
model: str | None = None
messages: list[ChatCompletionMessageParam]
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
user: str | None = None

# --8<-- [start:chat-classification-extra-params]
add_generation_prompt: bool = Field(
default=False,
description=(
"If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."
),
)

add_special_tokens: bool = Field(
default=False,
description=(
"If true, special tokens (e.g. BOS) will be added to the prompt "
"on top of what is added by the chat template. "
"For most models, the chat template takes care of adding the "
"special tokens so this should be set to false (as is the "
"default)."
),
)

chat_template: str | None = Field(
default=None,
description=(
"A Jinja template to use for this conversion. "
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
),
)

chat_template_kwargs: dict[str, Any] | None = Field(
default=None,
description=(
"Additional keyword args to pass to the template renderer. "
"Will be accessible by the chat template."
),
)

mm_processor_kwargs: dict[str, Any] | None = Field(
default=None,
description=("Additional kwargs to pass to the HF processor."),
)

priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
),
)

request_id: str = Field(
default_factory=lambda: f"{random_uuid()}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
softmax: bool | None = Field(
default=None,
description="softmax will be deprecated, please use use_activation instead.",
)

activation: bool | None = Field(
default=None,
description="activation will be deprecated, please use use_activation instead.",
)

use_activation: bool | None = Field(
default=None,
description="Whether to use activation for classification outputs. "
"Default is True.",
)
# --8<-- [end:chat-classification-extra-params]

def to_pooling_params(self):
return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=get_use_activation(self),
)


ClassificationRequest: TypeAlias = (
ClassificationCompletionRequest | ClassificationChatRequest
)


class ClassificationData(OpenAIBaseModel):
index: int
label: str | None
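To illustrate the new union type, a short sketch of how the two request shapes validate against the updated pydantic models; it assumes an environment with this change installed, importing from vllm.entrypoints.openai.protocol as the tests above do, and the model name is a placeholder.

# Sketch only: "my-model" is a placeholder model name.
from vllm.entrypoints.openai.protocol import (
    ClassificationChatRequest,
    ClassificationCompletionRequest,
)

# Completion-style payload: plain text (or a list of texts) under "input".
completion_req = ClassificationCompletionRequest.model_validate(
    {"model": "my-model", "input": "hello", "add_special_tokens": False}
)

# Chat-style payload: OpenAI-style "messages", rendered through the chat template.
chat_req = ClassificationChatRequest.model_validate(
    {
        "model": "my-model",
        "messages": [{"role": "user", "content": "Please classify this text."}],
    }
)

# Both variants expose the same pooling-parameter conversion.
print(completion_req.to_pooling_params())
print(chat_req.to_pooling_params())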