diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 08dbb4c45039..0117b8a0270a 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1259,7 +1259,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
             return False
 
         # No Embedding Models so far.
-        if model_config.task not in ["generate"]:
+        if model_config.task not in ["generate", "embed", "classify", "score", "reward"]:
             _raise_or_fallback(feature_name=f"--task {model_config.task}",
                                recommend_to_remove=False)
             return False
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index a04ab885a72b..5993edf6adc4 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -241,6 +241,7 @@ def __init__(
             **kwargs,
         )
 
+        logger.info(f"Engine args: {engine_args}")
         # Create the Engine (autoselects V0 vs V1)
         self.llm_engine = LLMEngine.from_engine_args(
             engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index 928fb231a1f2..a94d8a659a95 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -13,6 +13,7 @@
         KVConnectorMetadata)
     from vllm.lora.request import LoRARequest
     from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+    from vllm.pooling_params import PoolingParams
     from vllm.sampling_params import SamplingParams
     from vllm.v1.request import Request
 
@@ -26,6 +27,7 @@ class NewRequestData:
     mm_hashes: list[str]
     mm_positions: list[PlaceholderRange]
     sampling_params: SamplingParams
+    pooling_params: PoolingParams
     block_ids: list[int]
     num_computed_tokens: int
     lora_request: Optional[LoRARequest]
@@ -43,6 +45,7 @@ def from_request(
             mm_hashes=request.mm_hashes,
             mm_positions=request.mm_positions,
             sampling_params=request.sampling_params,
+            pooling_params=request.pooling_params,
             block_ids=block_ids,
             num_computed_tokens=request.num_computed_tokens,
             lora_request=request.lora_request,
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 05472ea573d3..3888d3b5a912 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -646,6 +646,7 @@ def update_from_output(
         spec_token_ids = model_runner_output.spec_token_ids
         logprobs = model_runner_output.logprobs
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
+        hidden_states = model_runner_output.hidden_states
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
 
         new_running: list[Request] = []
@@ -663,6 +664,21 @@ def update_from_output(
                 new_running.append(request)
                 continue
 
+            if hidden_states is not None:
+                request.status = RequestStatus.FINISHED_STOPPED
+                self._free_request(request)
+                outputs.append(
+                    EngineCoreOutput(
+                        request_id=req_id,
+                        new_token_ids=[],
+                        finish_reason=request.get_finished_reason(),
+                        new_logprobs=None,
+                        new_prompt_logprobs_tensors=None,
+                        stop_reason=request.stop_reason,
+                        events=request.take_events(),
+                        hidden_states=hidden_states))
+                continue
+
             req_index = model_runner_output.req_id_to_index[req_id]
             generated_token_ids = sampled_token_ids[req_index]
 
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index e33d1a1e5dcd..0d43e2da99eb 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -6,10 +6,12 @@
 from typing import Any, Optional, Union
 
 import msgspec
+import torch
 
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MultiModalKwargs
 from vllm.multimodal.inputs import PlaceholderRange
+from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import LogprobsLists, LogprobsTensors
@@ -53,6 +55,7 @@ class EngineCoreRequest(
     mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
     mm_hashes: Optional[list[str]]
     mm_placeholders: Optional[list[PlaceholderRange]]
+    pooling_params: Optional[PoolingParams]
     sampling_params: SamplingParams
     eos_token_id: Optional[int]
     arrival_time: float
@@ -106,6 +109,8 @@ class EngineCoreOutput(
     stop_reason: Union[int, str, None] = None
     events: Optional[list[EngineCoreEvent]] = None
 
+    hidden_states: Optional[torch.Tensor] = None
+
     @property
     def finished(self) -> bool:
         return self.finish_reason is not None
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 14ce820cc39e..8eeac89e6689 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -19,7 +19,7 @@
 from vllm.outputs import RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
@@ -212,11 +212,11 @@ async def add_request(
         if self.errored:
             raise EngineDeadError()
 
-        assert isinstance(params, SamplingParams), \
-            "Pooling is not supported in V1"
+        # assert isinstance(params, SamplingParams), \
+        #     "Pooling is not supported in V1"
 
         # Create a new output collector for the request.
-        queue = RequestOutputCollector(output_kind=params.output_kind)
+        queue = RequestOutputCollector(output_kind=RequestOutputKind.CUMULATIVE)
 
         # Convert Input --> Request.
         prompt_str, request = self.processor.process_inputs(
@@ -224,7 +224,7 @@
             tokenization_kwargs, trace_headers, prompt_adapter_request,
             priority)
 
-        if params.n == 1:
+        if isinstance(params, PoolingParams) or params.n == 1:
             await self._add_request(request, prompt_str, None, 0, queue)
             return queue
 
@@ -425,7 +425,7 @@ def _record_stats(
             stat_logger.record(scheduler_stats=scheduler_stats,
                                iteration_stats=iteration_stats)
 
-    def encode(
+    async def encode(
         self,
         prompt: PromptType,
         pooling_params: PoolingParams,
@@ -434,7 +434,61 @@ def encode(
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
     ):
-        raise ValueError("Not Supported on V1 yet.")
+        try:
+            # We start the output_handler on the first call to generate() so
+            # we can call __init__ before the event loop, which enables us
+            # to handle startup failure gracefully in the OpenAI server.
+            self._run_output_handler()
+
+            q = await self.add_request(
+                request_id,
+                prompt,
+                pooling_params,
+                lora_request=lora_request,
+                trace_headers=trace_headers,
+                prompt_adapter_request=None,
+                priority=priority,
+            )
+
+            # The output_handler task pushes items into the queue.
+            # This task pulls from the queue and yields to caller.
+            finished = False
+            while not finished:
+                # Note: drain queue without await if possible (avoids
+                # task switching under load which helps performance).
+                out = q.get_nowait() or await q.get()
+
+                # Note: both OutputProcessor and EngineCore handle their
+                # own request cleanup based on finished.
+                finished = out.finished
+                yield out
+
+        # If the request is disconnected by the client, generate()
+        # is cancelled. So, we abort the request if we end up here.
+        except asyncio.CancelledError:
+            await self.abort(request_id)
+            if self.log_requests:
+                logger.info("Request %s aborted.", request_id)
+            raise
+
+        # Engine is dead. Do not abort since we shut down.
+        except EngineDeadError:
+            if self.log_requests:
+                logger.info("Request %s failed (engine dead).", request_id)
+            raise
+
+        # Request validation error.
+        except ValueError:
+            if self.log_requests:
+                logger.info("Request %s failed (bad request).", request_id)
+            raise
+
+        # Unexpected error in the generate() task (possibly recoverable).
+        except Exception as e:
+            await self.abort(request_id)
+            if self.log_requests:
+                logger.info("Request %s failed.", request_id)
+            raise EngineGenerateError() from e
 
     async def get_vllm_config(self) -> VllmConfig:
         return self.vllm_config
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index e772615b7861..c092e303d63a 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -54,7 +54,7 @@ def __init__(self,
                  executor_class: type[Executor],
                  log_stats: bool,
                  executor_fail_callback: Optional[Callable] = None):
-        assert vllm_config.model_config.runner_type != "pooling"
+        # assert vllm_config.model_config.runner_type != "pooling"
 
         logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                     VLLM_VERSION, vllm_config)
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 5f5ffe6e09db..b542ad987c7f 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -5,7 +5,7 @@
 from dataclasses import dataclass
 from typing import Optional, Union
 
-from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.outputs import CompletionOutput, PoolingOutput, PoolingRequestOutput, RequestOutput
 from vllm.sampling_params import RequestOutputKind
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@@ -28,20 +28,20 @@ class RequestOutputCollector:
 
     def __init__(self, output_kind: RequestOutputKind):
         self.aggregate = output_kind == RequestOutputKind.DELTA
-        self.output: Optional[Union[RequestOutput, Exception]] = None
+        self.output: Optional[Union[RequestOutput, PoolingRequestOutput, Exception]] = None
         self.ready = asyncio.Event()
 
-    def put(self, output: Union[RequestOutput, Exception]) -> None:
+    def put(self, output: Union[RequestOutput, PoolingRequestOutput, Exception]) -> None:
         """Non-blocking put operation."""
         if self.output is None or isinstance(output, Exception):
             self.output = output
             self.ready.set()
-        elif isinstance(self.output, RequestOutput):
+        elif isinstance(self.output, RequestOutput) or isinstance(self.output, PoolingRequestOutput):
             # This ensures that request outputs with different request indexes
             # (if n > 1) do not override each other.
             self.output.add(output, aggregate=self.aggregate)
 
-    async def get(self) -> RequestOutput:
+    async def get(self) -> Union[RequestOutput, PoolingRequestOutput]:
         """Get operation blocks on put event."""
         while (output := self.output) is None:
             await self.ready.wait()
@@ -51,7 +51,7 @@ async def get(self) -> RequestOutput:
             raise output
         return output
 
-    def get_nowait(self) -> Optional[RequestOutput]:
+    def get_nowait(self) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
         """Non-blocking get operation."""
         output = self.output
         if output is not None:
@@ -65,7 +65,7 @@
 
 @dataclass
 class OutputProcessorOutput:
 
-    request_outputs: list[RequestOutput]
+    request_outputs: list[Union[RequestOutput, PoolingRequestOutput]]
     reqs_to_abort: list[str]
 
@@ -80,8 +80,8 @@ def __init__(
         output_kind: RequestOutputKind,
         prompt: Optional[str],
         prompt_token_ids: list[int],
-        logprobs_processor: LogprobsProcessor,
-        detokenizer: IncrementalDetokenizer,
+        logprobs_processor: Optional[LogprobsProcessor],
+        detokenizer: Optional[IncrementalDetokenizer],
         max_tokens_param: Optional[int],
         arrival_time: float,
         queue: Optional[RequestOutputCollector],
@@ -115,7 +115,7 @@ def from_new_request(
         queue: Optional[RequestOutputCollector],
         log_stats: bool,
     ) -> "RequestState":
-        if not request.sampling_params.detokenize:
+        if not request.sampling_params or not request.sampling_params.detokenize:
             tokenizer = None
         return cls(
             request_id=request.request_id,
@@ -123,17 +123,11 @@ def from_new_request(
             request_index=request_index,
             lora_name=(request.lora_request.name
                        if request.lora_request is not None else None),
-            output_kind=request.sampling_params.output_kind,
+            output_kind=RequestOutputKind.CUMULATIVE,
             prompt=prompt,
             prompt_token_ids=request.prompt_token_ids,
-            logprobs_processor=LogprobsProcessor.from_new_request(
-                tokenizer=tokenizer,
-                request=request,
-            ),
-            detokenizer=IncrementalDetokenizer.from_new_request(
-                tokenizer=tokenizer,
-                request=request,
-            ),
+            logprobs_processor=None,
+            detokenizer=None,
             max_tokens_param=(request.sampling_params.max_tokens
                               if request.sampling_params is not None else None),
             arrival_time=request.arrival_time,
@@ -146,7 +140,7 @@ def make_request_output(
         new_token_ids: list[int],
         finish_reason: Optional[FinishReason],
         stop_reason: Union[int, str, None],
-    ) -> Optional[RequestOutput]:
+    ) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
 
         finished = finish_reason is not None
         final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
@@ -175,12 +169,13 @@ def _new_request_output(
         outputs: list[CompletionOutput],
         finished: bool,
     ) -> RequestOutput:
-
-        if self.output_kind == RequestOutputKind.DELTA:
-            # Side effect: logprobs processor forgets prompt logprobs
-            prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
-        else:
-            prompt_logprobs = self.logprobs_processor.prompt_logprobs
+        prompt_logprobs = None
+        if self.logprobs_processor is not None:
+            if self.output_kind == RequestOutputKind.DELTA:
+                # Side effect: logprobs processor forgets prompt logprobs
+                prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
+            else:
+                prompt_logprobs = self.logprobs_processor.prompt_logprobs
 
         return RequestOutput(
             request_id=request_id,
@@ -201,22 +196,27 @@ def _new_completion_output(
 
         finished = finish_reason is not None
         delta = self.output_kind == RequestOutputKind.DELTA
 
-        # Prepare text and token_ids, based on delta mode
-        text = self.detokenizer.get_next_output_text(finished, delta)
-        if not delta:
-            token_ids = self.detokenizer.output_token_ids
+        text = ""
+        if self.detokenizer is not None:
+            # Prepare text and token_ids, based on delta mode
+            text = self.detokenizer.get_next_output_text(finished, delta)
+            if not delta:
+                token_ids = self.detokenizer.output_token_ids
 
-        # Prepare logprobs, based on delta mode
-        logprobs = self.logprobs_processor.logprobs
-        if delta and logprobs:
-            logprobs = logprobs[-len(token_ids):]
+        logprobs = None
+        if self.logprobs_processor is not None:
+            # Prepare logprobs, based on delta mode
+            logprobs = self.logprobs_processor.logprobs
+            if delta and logprobs:
+                logprobs = logprobs[-len(token_ids):]
 
         return CompletionOutput(
             index=self.request_index,
             text=text,
             token_ids=token_ids,
             logprobs=logprobs,
-            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
+            cumulative_logprob=self.logprobs_processor.cumulative_logprob
+            if self.logprobs_processor else None,
             finish_reason=str(finish_reason) if finished else None,
             stop_reason=stop_reason if finished else None)
@@ -318,7 +318,7 @@ def process_outputs(
         within the loop below.
         """
 
-        request_outputs: list[RequestOutput] = []
+        request_outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
         reqs_to_abort: list[str] = []
         for engine_core_output in engine_core_outputs:
             req_id = engine_core_output.request_id
@@ -335,28 +335,47 @@ def process_outputs(
 
             new_token_ids = engine_core_output.new_token_ids
             finish_reason = engine_core_output.finish_reason
             stop_reason = engine_core_output.stop_reason
+            hidden_states = engine_core_output.hidden_states
             req_state.is_prefilling = False
 
-            # 2) Detokenize the token ids into text and perform stop checks.
-            stop_string = req_state.detokenizer.update(
-                new_token_ids, finish_reason == FinishReason.STOP)
-            if stop_string:
-                finish_reason = FinishReason.STOP
-                stop_reason = stop_string
-
-            # 3) Compute sample and prompt logprobs for request, if required.
-            req_state.logprobs_processor.update_from_output(engine_core_output)
-
-            # 4) Create and handle RequestOutput objects.
-            if request_output := req_state.make_request_output(
-                    new_token_ids, finish_reason, stop_reason):
+            if hidden_states is not None:
+                # Process pooling request output
+                request_output = PoolingRequestOutput(
+                    request_id=req_id,
+                    outputs=PoolingOutput(data=hidden_states),
+                    prompt_token_ids=new_token_ids,
+                    finished=True,
+                )
                 if req_state.queue is not None:
                     # AsyncLLM: put into queue for handling by generate().
                     req_state.queue.put(request_output)
                 else:
                     # LLMEngine: return list of RequestOutputs.
                     request_outputs.append(request_output)
+            else:
+                # 2) Detokenize the token ids into text and perform stop checks.
+                if req_state.detokenizer:
+                    stop_string = req_state.detokenizer.update(
+                        new_token_ids, finish_reason == FinishReason.STOP)
+                    if stop_string:
+                        finish_reason = FinishReason.STOP
+                        stop_reason = stop_string
+
+                # 3) Compute sample and prompt logprobs for request, if required.
+                if req_state.logprobs_processor:
+                    req_state.logprobs_processor.update_from_output(
+                        engine_core_output)
+
+                # 4) Create and handle RequestOutput objects.
+                if request_output := req_state.make_request_output(
+                        new_token_ids, finish_reason, stop_reason):
+                    if req_state.queue is not None:
+                        # AsyncLLM: put into queue for handling by generate().
+                        req_state.queue.put(request_output)
+                    else:
+                        # LLMEngine: return list of RequestOutputs.
+                        request_outputs.append(request_output)
 
             # Free completed requests.
             if finish_reason is not None:
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 27d70a781471..4e073dd98133 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -129,7 +129,8 @@ def _validate_params(
         """
 
         if not isinstance(params, SamplingParams):
-            raise ValueError("V1 does not yet support Pooling models.")
+            return
+            # raise ValueError("V1 does not yet support Pooling models.")
 
         self._validate_logprobs(params)
         self._validate_sampling_params(params)
@@ -246,18 +247,23 @@ def process_inputs(
         if encoder_inputs is not None:
             raise NotImplementedError
 
-        assert isinstance(params, SamplingParams)
-        # TODO: can we avoid cloning here in multiproc case?
-        sampling_params = params.clone()
-        # If unset max tokens, then generate up to the max_model_len.
-        if sampling_params.max_tokens is None:
-            sampling_params.max_tokens = (
-                self.model_config.max_model_len -
-                len(decoder_inputs["prompt_token_ids"]))
-        sampling_params.update_from_generation_config(
-            self.generation_config_fields, eos_token_id)
-        sampling_params.update_from_tokenizer(
-            self.tokenizer.get_lora_tokenizer(lora_request))
+        # assert isinstance(params, SamplingParams)
+        pooling_params = None
+        sampling_params = SamplingParams.from_optional()
+        if isinstance(params, SamplingParams):
+            # TODO: can we avoid cloning here in multiproc case?
+            sampling_params = params.clone()
+            # If unset max tokens, then generate up to the max_model_len.
+            if sampling_params.max_tokens is None:
+                sampling_params.max_tokens = (
+                    self.model_config.max_model_len -
+                    len(decoder_inputs["prompt_token_ids"]))
+            sampling_params.update_from_generation_config(
+                self.generation_config_fields, eos_token_id)
+            sampling_params.update_from_tokenizer(
+                self.tokenizer.get_lora_tokenizer(lora_request))
+        elif isinstance(params, PoolingParams):
+            pooling_params = params.clone()
 
         # Multimodal related.
         sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
@@ -314,6 +320,7 @@ def process_inputs(
             mm_hashes=sorted_mm_hashes,
             mm_placeholders=sorted_mm_positions,
             sampling_params=sampling_params,
+            pooling_params=pooling_params,
             eos_token_id=eos_token_id,
             arrival_time=arrival_time,
             lora_request=lora_request,
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index 2732b933c28a..eab9ec22eeac 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -100,6 +100,9 @@ class ModelRunnerOutput:
     # [prompt_len]
     prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]
 
+    # Used for pooling
+    hidden_states: Optional[torch.Tensor]
+
 
 EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
     req_ids=[],
@@ -108,4 +111,5 @@ class ModelRunnerOutput:
     spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
+    hidden_states=None,
 )
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index fde366d61c7d..2cb83d235fca 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -4,6 +4,7 @@
 from typing import TYPE_CHECKING, Optional, Union
 
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.utils import is_list_of
 from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
@@ -30,6 +31,7 @@ def __init__(
         lora_request: Optional["LoRARequest"] = None,
         structured_output_request: Optional["StructuredOutputRequest"] = None,
         cache_salt: Optional[str] = None,
+        pooling_params: Optional[PoolingParams] = None,
     ) -> None:
         self.request_id = request_id
         self.sampling_params = sampling_params
@@ -72,6 +74,8 @@ def __init__(
         self.output_token_ids = ConstantList(self._output_token_ids)
         self.all_token_ids = ConstantList(self._all_token_ids)
 
+        self.pooling_params = pooling_params
+
     @classmethod
     def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
         if request.mm_inputs is not None:
@@ -86,6 +90,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
             multi_modal_hashes=request.mm_hashes,
             multi_modal_placeholders=request.mm_placeholders,
             sampling_params=request.sampling_params,
+            pooling_params=request.pooling_params,
             eos_token_id=request.eos_token_id,
             arrival_time=request.arrival_time,
             lora_request=request.lora_request,
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index c00424dfea73..db080c86d731 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -9,6 +9,7 @@
 
 from vllm.lora.request import LoRARequest
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.utils import swap_dict_values
 from vllm.v1.outputs import LogprobsTensors
@@ -27,6 +28,7 @@ class CachedRequestState:
     mm_inputs: list[MultiModalKwargs]
     mm_positions: list[PlaceholderRange]
     sampling_params: SamplingParams
+    pooling_params: PoolingParams
     generator: Optional[torch.Generator]
 
     block_ids: list[int]
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 97d8c91b4659..ad13da1236a9 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -350,6 +350,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 mm_inputs=new_req_data.mm_inputs,
                 mm_positions=new_req_data.mm_positions,
                 sampling_params=sampling_params,
+                pooling_params=new_req_data.pooling_params,
                 generator=generator,
                 block_ids=new_req_data.block_ids,
                 num_computed_tokens=new_req_data.num_computed_tokens,
@@ -1118,6 +1119,18 @@ def execute_model(
             # For mid-pipeline stages, return the hidden states.
             return hidden_states
 
+        if self.model_config.runner_type == "pooling":
+            hidden_states = torch.mean(hidden_states, dim=0).cpu()
+            return ModelRunnerOutput(
+                req_ids=self.input_batch.req_ids,
+                req_id_to_index=self.input_batch.req_id_to_index,
+                sampled_token_ids=[],
+                spec_token_ids=[],
+                logprobs=None,
+                prompt_logprobs_dict={},
+                hidden_states=hidden_states,
+            )
+
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
 
@@ -1291,6 +1304,7 @@ def execute_model(
             spec_token_ids=spec_token_ids,
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
+            hidden_states=None,
         )
 
     def generate_draft_token_ids(
@@ -1524,6 +1538,7 @@ def _dummy_sampler_run(
         self,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
+        return
 
         logits = self.model.compute_logits(hidden_states, None)
         num_reqs = logits.size(0)
@@ -1661,7 +1676,8 @@ def profile_run(self) -> None:
         hidden_states = self._dummy_run(self.max_num_tokens)
 
         if get_pp_group().is_last_rank:
-            sampler_output = self._dummy_sampler_run(hidden_states)
+            # sampler_output = self._dummy_sampler_run(hidden_states)
+            sampler_output = None
         else:
             sampler_output = None
         torch.cuda.synchronize()
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index d716542f7898..6f2aadf2ab60 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -359,6 +359,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
                 mm_inputs=new_req_data.mm_inputs,
                 mm_positions=new_req_data.mm_positions,
                 sampling_params=sampling_params,
+                pooling_params=new_req_data.pooling_params,
                 generator=None,
                 block_ids=new_req_data.block_ids,
                 num_computed_tokens=new_req_data.num_computed_tokens,
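
For reference, a minimal offline smoke test for the pooling path wired up above. This is a sketch, not part of the patch: the model name and the VLLM_USE_V1 toggle are assumptions about the test environment, and the printed shape depends on the chosen model.

# Hypothetical usage sketch: exercise the V1 pooling path end-to-end with a
# build that has this patch applied. The embedding model below is illustrative.
import os

os.environ["VLLM_USE_V1"] = "1"  # select the V1 engine before importing vllm

from vllm import LLM

llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed")

# LLM.encode() returns PoolingRequestOutput objects; with this patch,
# `outputs.data` carries the mean-pooled hidden state computed in
# GPUModelRunner.execute_model and forwarded via EngineCoreOutput.
outputs = llm.encode(["The quick brown fox jumps over the lazy dog."])
for out in outputs:
    print(out.request_id, out.outputs.data.shape)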