diff --git a/tests/entrypoints/pooling/openai/test_classification.py b/tests/entrypoints/pooling/openai/test_classification.py
index 671bb948780a..25080d4189c2 100644
--- a/tests/entrypoints/pooling/openai/test_classification.py
+++ b/tests/entrypoints/pooling/openai/test_classification.py
@@ -46,6 +46,16 @@ def test_single_input_classification(server: RemoteOpenAIServer, model_name: str
     assert hasattr(output.data[0], "probs")
 
 
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+def test_add_special_tokens_false(server: RemoteOpenAIServer, model_name: str):
+    response = requests.post(
+        server.url_for("classify"),
+        json={"model": model_name, "input": "hello", "add_special_tokens": False},
+    )
+    response.raise_for_status()
+    ClassificationResponse.model_validate(response.json())
+
+
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 def test_multiple_inputs_classification(server: RemoteOpenAIServer, model_name: str):
     input_texts = [
diff --git a/tests/entrypoints/pooling/openai/test_vision_classification.py b/tests/entrypoints/pooling/openai/test_vision_classification.py
new file mode 100644
index 000000000000..f2616e057b17
--- /dev/null
+++ b/tests/entrypoints/pooling/openai/test_vision_classification.py
@@ -0,0 +1,95 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import json
+
+import pytest
+import requests
+
+from tests.utils import RemoteOpenAIServer
+from vllm.entrypoints.openai.protocol import ClassificationResponse
+
+VLM_MODEL_NAME = "muziyongshixin/Qwen2.5-VL-7B-for-VideoCls"
+MAXIMUM_VIDEOS = 1
+TEST_VIDEO_URL = "https://www.bogotobogo.com/python/OpenCV_Python/images/mean_shift_tracking/slow_traffic_small.mp4"
+
+HF_OVERRIDES = {
+    "text_config": {
+        "architectures": ["Qwen2_5_VLForSequenceClassification"],
+    },
+}
+
+
+@pytest.fixture(scope="module")
+def server_vlm_classify():
+    args = [
+        "--runner",
+        "pooling",
+        "--max-model-len",
+        "5000",
+        "--enforce-eager",
+        "--limit-mm-per-prompt",
+        json.dumps({"video": MAXIMUM_VIDEOS}),
+    ]
+
+    with RemoteOpenAIServer(
+        VLM_MODEL_NAME, args, override_hf_configs=HF_OVERRIDES
+    ) as remote_server:
+        yield remote_server
+
+
+@pytest.mark.parametrize("model_name", [VLM_MODEL_NAME])
+def test_classify_accepts_chat_text_only(
+    server_vlm_classify: RemoteOpenAIServer, model_name: str
+) -> None:
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Please classify this text request."},
+            ],
+        }
+    ]
+
+    response = requests.post(
+        server_vlm_classify.url_for("classify"),
+        json={"model": model_name, "messages": messages},
+    )
+    response.raise_for_status()
+
+    output = ClassificationResponse.model_validate(response.json())
+
+    assert output.object == "list"
+    assert output.model == model_name
+    assert len(output.data) == 1
+    assert len(output.data[0].probs) == 2
+    assert output.usage.prompt_tokens == 22
+
+
+@pytest.mark.parametrize("model_name", [VLM_MODEL_NAME])
+def test_classify_accepts_chat_video_url(
+    server_vlm_classify: RemoteOpenAIServer, model_name: str
+) -> None:
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Please classify this video."},
+                {"type": "video_url", "video_url": {"url": TEST_VIDEO_URL}},
+            ],
+        }
+    ]
+
+    response = requests.post(
+        server_vlm_classify.url_for("classify"),
+        json={"model": model_name, "messages": messages},
+    )
+    response.raise_for_status()
+
+    output = ClassificationResponse.model_validate(response.json())
+
+    assert output.object == "list"
+    assert output.model == model_name
+    assert len(output.data) == 1
+    assert len(output.data[0].probs) == 2
+    assert output.usage.prompt_tokens == 4807
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index fbb2d32a229d..f30c6ef2cd0a 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -1784,6 +1784,9 @@ async def init_app_state(
             engine_client,
             state.openai_serving_models,
             request_logger=request_logger,
+            chat_template=resolved_chat_template,
+            chat_template_content_format=args.chat_template_content_format,
+            trust_request_chat_template=args.trust_request_chat_template,
             log_error_stack=args.log_error_stack,
         )
         if "classify" in supported_tasks
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 69e757d4764d..45584df8b9e2 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -2000,10 +2000,10 @@ class ScoreResponse(OpenAIBaseModel):
     usage: UsageInfo
 
 
-class ClassificationRequest(OpenAIBaseModel):
+class ClassificationCompletionRequest(OpenAIBaseModel):
     model: str | None = None
     input: list[str] | str
-    truncate_prompt_tokens: int | None = None
+    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
     user: str | None = None
 
     # --8<-- [start:classification-extra-params]
@@ -2015,7 +2015,21 @@ class ClassificationRequest(OpenAIBaseModel):
             "if the served model does not use priority scheduling."
         ),
     )
-
+    add_special_tokens: bool = Field(
+        default=True,
+        description=(
+            "If true (the default), special tokens (e.g. BOS) will be added to "
+            "the prompt."
+        ),
+    )
+    request_id: str = Field(
+        default_factory=lambda: f"{random_uuid()}",
+        description=(
+            "The request_id related to this request. If the caller does "
+            "not set it, a random_uuid will be generated. This id is used "
+            "through out the inference process and return in response."
+        ),
+    )
     softmax: bool | None = Field(
         default=None,
         description="softmax will be deprecated, please use use_activation instead.",
@@ -2040,6 +2054,102 @@ def to_pooling_params(self):
         )
 
 
+class ClassificationChatRequest(OpenAIBaseModel):
+    model: str | None = None
+    messages: list[ChatCompletionMessageParam]
+    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
+    user: str | None = None
+
+    # --8<-- [start:chat-classification-extra-params]
+    add_generation_prompt: bool = Field(
+        default=False,
+        description=(
+            "If true, the generation prompt will be added to the chat template. "
+            "This is a parameter used by chat template in tokenizer config of the "
+            "model."
+        ),
+    )
+
+    add_special_tokens: bool = Field(
+        default=False,
+        description=(
+            "If true, special tokens (e.g. BOS) will be added to the prompt "
+            "on top of what is added by the chat template. "
+            "For most models, the chat template takes care of adding the "
+            "special tokens so this should be set to false (as is the "
+            "default)."
+        ),
+    )
+
+    chat_template: str | None = Field(
+        default=None,
+        description=(
+            "A Jinja template to use for this conversion. "
+            "As of transformers v4.44, default chat template is no longer "
+            "allowed, so you must provide a chat template if the tokenizer "
+            "does not define one."
+        ),
+    )
+
+    chat_template_kwargs: dict[str, Any] | None = Field(
+        default=None,
+        description=(
+            "Additional keyword args to pass to the template renderer. "
+            "Will be accessible by the chat template."
+        ),
+    )
+
+    mm_processor_kwargs: dict[str, Any] | None = Field(
+        default=None,
+        description=("Additional kwargs to pass to the HF processor."),
+    )
+
+    priority: int = Field(
+        default=0,
+        description=(
+            "The priority of the request (lower means earlier handling; "
+            "default: 0). Any priority other than 0 will raise an error "
+            "if the served model does not use priority scheduling."
+        ),
+    )
+
+    request_id: str = Field(
+        default_factory=lambda: f"{random_uuid()}",
+        description=(
+            "The request_id related to this request. If the caller does "
+            "not set it, a random_uuid will be generated. This id is used "
+            "through out the inference process and return in response."
+        ),
+    )
+    softmax: bool | None = Field(
+        default=None,
+        description="softmax will be deprecated, please use use_activation instead.",
+    )
+
+    activation: bool | None = Field(
+        default=None,
+        description="activation will be deprecated, please use use_activation instead.",
+    )
+
+    use_activation: bool | None = Field(
+        default=None,
+        description="Whether to use activation for classification outputs. "
+        "Default is True.",
+    )
+    # --8<-- [end:chat-classification-extra-params]
+
+    def to_pooling_params(self):
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            use_activation=get_use_activation(self),
+        )
+
+
+ClassificationRequest: TypeAlias = (
+    ClassificationCompletionRequest | ClassificationChatRequest
+)
+
+
 class ClassificationData(OpenAIBaseModel):
     index: int
     label: str | None
diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py
index 45bbe732a680..167ee152fece 100644
--- a/vllm/entrypoints/openai/serving_classification.py
+++ b/vllm/entrypoints/openai/serving_classification.py
@@ -4,13 +4,17 @@
 from http import HTTPStatus
 from typing import cast
 
+import jinja2
 import numpy as np
 from fastapi import Request
-from typing_extensions import override
 
 from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    ClassificationChatRequest,
+    ClassificationCompletionRequest,
     ClassificationData,
     ClassificationRequest,
     ClassificationResponse,
@@ -32,7 +36,10 @@
 
 
 class ClassificationMixin(OpenAIServing):
-    @override
+    chat_template: str | None
+    chat_template_content_format: ChatTemplateContentFormatOption
+    trust_request_chat_template: bool
+
     async def _preprocess(
         self,
         ctx: ServeContext,
@@ -42,31 +49,79 @@ async def _preprocess(
         and prepare model-specific inputs.
         """
         ctx = cast(ClassificationServeContext, ctx)
 
-        if isinstance(ctx.request.input, str) and not ctx.request.input:
-            return self.create_error_response(
-                "Input cannot be empty for classification",
-                status_code=HTTPStatus.BAD_REQUEST,
-            )
-
-        if isinstance(ctx.request.input, list) and len(ctx.request.input) == 0:
-            return None
-
         try:
             ctx.tokenizer = await self.engine_client.get_tokenizer()
-            renderer = self._get_renderer(ctx.tokenizer)
-            ctx.engine_prompts = await renderer.render_prompt(
-                prompt_or_prompts=ctx.request.input,
-                config=self._build_render_config(ctx.request),
-            )
+            request_obj = ctx.request
+
+            if isinstance(request_obj, ClassificationChatRequest):
+                chat_request = request_obj
+                messages = chat_request.messages
+                trust_request_chat_template = getattr(
+                    self,
+                    "trust_request_chat_template",
+                    False,
+                )
+                ret = self._validate_chat_template(
+                    request_chat_template=chat_request.chat_template,
+                    chat_template_kwargs=chat_request.chat_template_kwargs,
+                    trust_request_chat_template=trust_request_chat_template,
+                )
+                if ret:
+                    return ret
+
+                (
+                    _,
+                    _,
+                    engine_prompts,
+                ) = await self._preprocess_chat(
+                    cast(ChatCompletionRequest, chat_request),
+                    ctx.tokenizer,
+                    messages,
+                    chat_template=(
+                        chat_request.chat_template
+                        or getattr(self, "chat_template", None)
+                    ),
+                    chat_template_content_format=cast(
+                        ChatTemplateContentFormatOption,
+                        getattr(self, "chat_template_content_format", "auto"),
+                    ),
+                    add_generation_prompt=False,
+                    continue_final_message=False,
+                    add_special_tokens=chat_request.add_special_tokens,
+                )
+                ctx.engine_prompts = engine_prompts
+
+            elif isinstance(request_obj, ClassificationCompletionRequest):
+                completion_request = request_obj
+                input_data = completion_request.input
+                if input_data in (None, ""):
+                    return self.create_error_response(
+                        "Input or messages must be provided",
+                        status_code=HTTPStatus.BAD_REQUEST,
+                    )
+                if isinstance(input_data, list) and not input_data:
+                    ctx.engine_prompts = []
+                    return None
+
+                renderer = self._get_renderer(ctx.tokenizer)
+                prompt_input = cast(str | list[str], input_data)
+                ctx.engine_prompts = await renderer.render_prompt(
+                    prompt_or_prompts=prompt_input,
+                    config=self._build_render_config(completion_request),
+                )
+            else:
+                return self.create_error_response(
+                    "Invalid classification request type",
+                    status_code=HTTPStatus.BAD_REQUEST,
+                )
 
             return None
 
-        except (ValueError, TypeError) as e:
+        except (ValueError, TypeError, jinja2.TemplateError) as e:
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))
 
-    @override
     def _build_response(
         self,
         ctx: ServeContext,
@@ -118,6 +173,7 @@ def _build_render_config(self, request: ClassificationRequest) -> RenderConfig:
         return RenderConfig(
             max_length=self.max_model_len,
             truncate_prompt_tokens=request.truncate_prompt_tokens,
+            add_special_tokens=request.add_special_tokens,
         )
 
 
@@ -130,6 +186,9 @@ def __init__(
         models: OpenAIServingModels,
         *,
         request_logger: RequestLogger | None,
+        chat_template: str | None = None,
+        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
+        trust_request_chat_template: bool = False,
         log_error_stack: bool = False,
     ) -> None:
         super().__init__(
@@ -139,6 +198,10 @@ def __init__(
             log_error_stack=log_error_stack,
         )
 
+        self.chat_template = chat_template
+        self.chat_template_content_format = chat_template_content_format
+        self.trust_request_chat_template = trust_request_chat_template
+
     async def create_classify(
         self,
         request: ClassificationRequest,
@@ -156,7 +219,6 @@ async def create_classify(
 
         return await super().handle(ctx)  # type: ignore
 
-    @override
     def _create_pooling_params(
         self,
         ctx: ClassificationServeContext,
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 1456727a3cdd..03f10e5a91e6 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -43,6 +43,8 @@
     ChatCompletionNamedToolChoiceParam,
     ChatCompletionRequest,
     ChatCompletionResponse,
+    ClassificationChatRequest,
+    ClassificationCompletionRequest,
     ClassificationRequest,
     ClassificationResponse,
     CompletionRequest,
@@ -114,13 +116,16 @@
     | DetokenizeRequest
     | EmbeddingCompletionRequest
     | RerankRequest
-    | ClassificationRequest
+    | ClassificationCompletionRequest
     | ScoreRequest
     | TokenizeCompletionRequest
 )
 
 ChatLikeRequest: TypeAlias = (
-    ChatCompletionRequest | EmbeddingChatRequest | TokenizeChatRequest
+    ChatCompletionRequest
+    | EmbeddingChatRequest
+    | TokenizeChatRequest
+    | ClassificationChatRequest
 )
 SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest
 AnyRequest: TypeAlias = (
@@ -814,7 +819,11 @@ def _get_message_types(self, request: AnyRequest) -> set[str]:
         if not hasattr(request, "messages"):
             return message_types
 
-        for message in request.messages:
+        messages = request.messages
+        if messages is None or isinstance(messages, (str, bytes)):
+            return message_types
+
+        for message in messages:
             if (
                 isinstance(message, dict)
                 and "content" in message
@@ -907,7 +916,8 @@ def _validate_input(
                 EmbeddingCompletionRequest,
                 ScoreRequest,
                 RerankRequest,
-                ClassificationRequest,
+                ClassificationCompletionRequest,
+                ClassificationChatRequest,
             ),
         ):
             # Note: input length can be up to the entire model context length
@@ -915,7 +925,8 @@ def _validate_input(
             if token_num > self.max_model_len:
                 operations: dict[type[AnyRequest], str] = {
                     ScoreRequest: "score",
-                    ClassificationRequest: "classification",
+                    ClassificationCompletionRequest: "classification",
+                    ClassificationChatRequest: "classification",
                 }
                 operation = operations.get(type(request), "embedding generation")
                 raise ValueError(
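
Usage sketch (not part of the patch): the snippet below shows how a client could exercise the new chat-style classification request, mirroring test_classify_accepts_chat_text_only above. The /classify route and the payload shape come from the tests in this diff; the base URL and the served model name are illustrative assumptions.

# Hypothetical client example; BASE_URL assumes a locally running vLLM server
# started with --runner pooling, as in the fixture above.
import requests

BASE_URL = "http://localhost:8000"

messages = [
    {
        "role": "user",
        "content": [{"type": "text", "text": "Please classify this text request."}],
    }
]

response = requests.post(
    f"{BASE_URL}/classify",
    json={"model": "muziyongshixin/Qwen2.5-VL-7B-for-VideoCls", "messages": messages},
)
response.raise_for_status()
print(response.json()["data"][0]["probs"])  # per-label probabilities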