From acafb6c0f2b10ab2942d9595cbfe329d4849be54 Mon Sep 17 00:00:00 2001 From: CHEN <116010019@link.cuhk.edu.cn> Date: Tue, 25 Nov 2025 21:18:07 +0800 Subject: [PATCH] recompute scheduler adapt main Signed-off-by: CHEN <116010019@link.cuhk.edu.cn> --- vllm_ascend/core/recompute_scheduler.py | 1080 ++++++++++++++--------- 1 file changed, 686 insertions(+), 394 deletions(-) diff --git a/vllm_ascend/core/recompute_scheduler.py b/vllm_ascend/core/recompute_scheduler.py index d04f8f85500..e7f53337371 100644 --- a/vllm_ascend/core/recompute_scheduler.py +++ b/vllm_ascend/core/recompute_scheduler.py @@ -14,58 +14,64 @@ # limitations under the License. # This file is a part of the vllm-ascend project. # -from __future__ import annotations - import itertools import time from collections import defaultdict from collections.abc import Iterable from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any -import numpy as np -import numpy.typing as npt +from vllm import envs from vllm.config import VllmConfig +from vllm.distributed.ec_transfer.ec_connector.base import ( + ECConnectorMetadata, + ECConnectorRole, +) +from vllm.distributed.ec_transfer.ec_connector.factory import ECConnectorFactory from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch -from vllm.distributed.kv_transfer.kv_connector.factory import \ - KVConnectorFactory -from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, - KVConnectorRole) -from vllm.distributed.kv_transfer.kv_connector.v1.base import \ - KVConnectorMetadata -from vllm.distributed.kv_transfer.kv_connector.v1.metrics import \ - KVConnectorStats -from vllm.logger import logger +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory +from vllm.distributed.kv_transfer.kv_connector.v1 import ( + KVConnectorBase_V1, + KVConnectorRole, + SupportsHMA, +) +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats +from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, - compute_encoder_budget) +from vllm.v1.core.encoder_cache_manager import ( + EncoderCacheManager, + compute_encoder_budget, +) from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.sched.interface import SchedulerInterface -from vllm.v1.core.sched.output import CachedRequestData, NewRequestData -from vllm.v1.core.sched.request_queue import (SchedulingPolicy, - create_request_queue) +from vllm.v1.core.sched.output import ( + CachedRequestData, + GrammarOutput, + NewRequestData, + SchedulerOutput, +) +from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue from vllm.v1.core.sched.utils import check_stop, remove_all -from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, - EngineCoreOutputs, FinishReason) +from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs, FinishReason from vllm.v1.kv_cache_interface import KVCacheConfig -from vllm.v1.metrics.stats import SchedulerStats +from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.structured_output import StructuredOutputManager 
-from vllm.v1.utils import ConstantList +from vllm.v1.utils import record_function_or_nullcontext, ConstantList +logger = init_logger(__name__) -class RecomputeScheduler(SchedulerInterface): - """This Scheduler extends vllm's original v1 scheduler of version 0.11 - to fix recomputing bug.""" +class Scheduler(SchedulerInterface): def __init__( self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig, structured_output_manager: StructuredOutputManager, - block_size: Optional[int] = None, + block_size: int, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, include_finished_set: bool = False, log_stats: bool = False, @@ -85,61 +91,63 @@ def __init__( # request ids should be included in the EngineCoreOutputs returned # by update_from_outputs(). This is currently used in the multi-engine # case to track request lifetimes efficiently. - self.finished_req_ids_dict: Optional[dict[int, set[str]]] = ( - defaultdict(set) if include_finished_set else None) + self.finished_req_ids_dict: dict[int, set[str]] | None = ( + defaultdict(set) if include_finished_set else None + ) + self.prev_step_scheduled_req_ids: set[str] = set() # Scheduling constraints. self.max_num_running_reqs = self.scheduler_config.max_num_seqs - self.max_num_scheduled_tokens = \ - self.scheduler_config.max_num_batched_tokens - self.max_model_len = self.scheduler_config.max_model_len + self.max_num_scheduled_tokens = self.scheduler_config.max_num_batched_tokens + self.max_model_len = vllm_config.model_config.max_model_len self.enable_kv_cache_events = ( self.kv_events_config is not None - and self.kv_events_config.enable_kv_cache_events) + and self.kv_events_config.enable_kv_cache_events + ) # Create KVConnector for the Scheduler. Note that each Worker # will have a corresponding KVConnector with Role=WORKER. # KV Connector pushes/pull of remote KVs for P/D and offloading. self.connector = None + self.connector_prefix_cache_stats: PrefixCacheStats | None = None if self.vllm_config.kv_transfer_config is not None: - assert len(self.kv_cache_config.kv_cache_groups) == 1, ( - "Multiple KV cache groups are not currently supported " - "with KV connectors") assert not self.is_encoder_decoder, ( - "Encoder-decoder models are not currently supported " - "with KV connectors") + "Encoder-decoder models are not currently supported with KV connectors" + ) self.connector = KVConnectorFactory.create_connector( - config=self.vllm_config, role=KVConnectorRole.SCHEDULER) + config=self.vllm_config, + role=KVConnectorRole.SCHEDULER, + kv_cache_config=self.kv_cache_config, + ) + if self.log_stats: + self.connector_prefix_cache_stats = PrefixCacheStats() self.kv_event_publisher = EventPublisherFactory.create( self.kv_events_config, self.parallel_config.data_parallel_rank, ) + self.ec_connector = None + if self.vllm_config.ec_transfer_config is not None: + self.ec_connector = ECConnectorFactory.create_connector( + config=self.vllm_config, role=ECConnectorRole.SCHEDULER + ) num_gpu_blocks = self.cache_config.num_gpu_blocks assert num_gpu_blocks is not None and num_gpu_blocks > 0 - self.block_size = self.cache_config.block_size - - self.dcp_world_size = \ - vllm_config.parallel_config.decode_context_parallel_size - # Note(hc): The scheduler’s block_size must be multiplied - # by dcp_world_size, since block hashes are computed on the - # original full token sequence at a granularity of - # original_block_size × dcp_world_size. 
- if self.dcp_world_size > 1: - self.block_size *= self.dcp_world_size + self.block_size = block_size + self.dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size + self.pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size # req_id -> Request self.requests: dict[str, Request] = {} # Scheduling policy - if self.scheduler_config.policy == "priority": - self.policy = SchedulingPolicy.PRIORITY - elif self.scheduler_config.policy == "fcfs": - self.policy = SchedulingPolicy.FCFS - else: + try: + self.policy = SchedulingPolicy(self.scheduler_config.policy) + except ValueError as e: raise ValueError( - f"Unknown scheduling policy: {self.scheduler_config.policy}") + f"Unknown scheduling policy: {self.scheduler_config.policy}" + ) from e # Priority queues for requests. self.waiting = create_request_queue(self.policy) self.running: list[Request] = [] @@ -152,6 +160,7 @@ def __init__( # KV Connector: requests in process of async KV loading or recving self.finished_recving_kv_req_ids: set[str] = set() + self.failed_recving_kv_req_ids: set[str] = set() # Encoder-related. # Calculate encoder cache size if applicable @@ -171,8 +180,7 @@ def __init__( # NOTE: For the models without encoder (e.g., text-only models), # the encoder cache will not be initialized because cache size is 0 # for these models. - self.encoder_cache_manager = EncoderCacheManager( - cache_size=encoder_cache_size) + self.encoder_cache_manager = EncoderCacheManager(cache_size=encoder_cache_size) speculative_config = vllm_config.speculative_config self.use_eagle = False @@ -192,21 +200,27 @@ def __init__( log_stats=self.log_stats, enable_kv_cache_events=self.enable_kv_cache_events, dcp_world_size=self.dcp_world_size, + pcp_world_size=self.pcp_world_size, ) self.use_pp = self.parallel_config.pipeline_parallel_size > 1 - - def schedule(self) -> RecomputeSchedulerOutput: - """This scheduler extends vLLM's original v1 scheduler - by introducing a decoding instance recomputing scheduling strategy. - Specifically, if a request is preempted in the decoding instance, - it halts the process with the recomputed symbol and recalculates - its KVC in the prefill instance.""" + self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER + + def schedule(self) -> SchedulerOutput: + # NOTE(woosuk) on the scheduling algorithm: + # There's no "decoding phase" nor "prefill phase" in the scheduler. + # Each request just has the num_computed_tokens and + # num_tokens_with_spec. num_tokens_with_spec = + # len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids). + # At each step, the scheduler tries to assign tokens to the requests + # so that each request's num_computed_tokens can catch up its + # num_tokens_with_spec. This is general enough to cover + # chunked prefills, prefix caching, speculative decoding, + # and the "jump decoding" optimization in the future. scheduled_new_reqs: list[Request] = [] scheduled_resumed_reqs: list[Request] = [] scheduled_running_reqs: list[Request] = [] preempted_reqs: list[Request] = [] - recomputed_reqs: list[RecomputeReqInfo] = [] req_to_new_blocks: dict[str, KVCacheBlocks] = {} num_scheduled_tokens: dict[str, int] = {} @@ -220,35 +234,48 @@ def schedule(self) -> RecomputeSchedulerOutput: # For logging. scheduled_timestamp = time.monotonic() + recomputed_reqs: list[RecomputeReqInfo] = [] + # First, schedule the RUNNING requests. 
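The NOTE above frames scheduling as each request's num_computed_tokens catching up to num_tokens_with_spec under a shared token budget. A minimal standalone sketch of that accounting follows (illustration only, outside the patched file; the _Req dataclass is a plain stand-in for vLLM's Request):

from dataclasses import dataclass

@dataclass
class _Req:
    num_prompt_tokens: int
    num_output_tokens: int = 0
    num_spec_tokens: int = 0
    num_computed_tokens: int = 0

    @property
    def num_tokens_with_spec(self) -> int:
        return self.num_prompt_tokens + self.num_output_tokens + self.num_spec_tokens

def tokens_this_step(req: _Req, token_budget: int, long_prefill_threshold: int = 0) -> int:
    # How far this request may "catch up" in the current step.
    gap = req.num_tokens_with_spec - req.num_computed_tokens
    if 0 < long_prefill_threshold < gap:
        gap = long_prefill_threshold        # chunk very long prefills
    return min(gap, token_budget)

# A 10-token prompt under an 8-token budget is chunk-prefilled over two steps.
r = _Req(num_prompt_tokens=10)
assert tokens_this_step(r, token_budget=8) == 8
r.num_computed_tokens += 8
assert tokens_this_step(r, token_budget=8) == 2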
req_index = 0 while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] - num_new_tokens = (request.num_tokens_with_spec + - request.num_output_placeholders - - request.num_computed_tokens) - if (0 < self.scheduler_config.long_prefill_token_threshold < - num_new_tokens): - num_new_tokens = ( - self.scheduler_config.long_prefill_token_threshold) + num_new_tokens = ( + request.num_tokens_with_spec + + request.num_output_placeholders + - request.num_computed_tokens + ) + if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: + num_new_tokens = self.scheduler_config.long_prefill_token_threshold num_new_tokens = min(num_new_tokens, token_budget) - # Make sure the input position does not exceed the max model len. - # This is necessary when using spec decoding. + # Make sure the input position does not exceed the max model len or + # request's max_tokens. + # This is necessary when using spec decoding and/or async scheduling. + max_total_tokens = min( + request.num_prompt_tokens + request.max_tokens, self.max_model_len + ) num_new_tokens = min( - num_new_tokens, - self.max_model_len - 1 - request.num_computed_tokens) + num_new_tokens, max_total_tokens - 1 - request.num_computed_tokens + ) # Schedule encoder inputs. encoder_inputs_to_schedule = None + external_load_encoder_input: list[int] = [] new_encoder_compute_budget = encoder_compute_budget if request.has_encoder_inputs: - (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_compute_budget - ) = self._try_schedule_encoder_inputs( - request, request.num_computed_tokens, num_new_tokens, - encoder_compute_budget) + ( + encoder_inputs_to_schedule, + num_new_tokens, + new_encoder_compute_budget, + external_load_encoder_input, + ) = self._try_schedule_encoder_inputs( + request, + request.num_computed_tokens, + num_new_tokens, + encoder_compute_budget, + ) if num_new_tokens == 0: # The request cannot be scheduled because one of the following @@ -266,12 +293,18 @@ def schedule(self) -> RecomputeSchedulerOutput: req_index += 1 continue - while True: - new_blocks = self.kv_cache_manager.allocate_slots( - request, - num_new_tokens, - num_lookahead_tokens=self.num_lookahead_tokens) - if new_blocks is None: + # Schedule newly needed KV blocks for the request. + with record_function_or_nullcontext("schedule: allocate_slots"): + while True: + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens, + num_lookahead_tokens=self.num_lookahead_tokens, + ) + + if new_blocks is not None: + # The request can be scheduled. + break transfer_config = self.vllm_config.kv_transfer_config if transfer_config is not None and not transfer_config.is_kv_producer: recomputed_req = self.running.pop() @@ -281,7 +314,6 @@ def schedule(self) -> RecomputeSchedulerOutput: recomputed_req.output_token_ids, recomputed_req.client_index)) if recomputed_req == request: - can_schedule = False break else: # The request cannot be scheduled. 
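When allocate_slots() keeps returning None, the loop above evicts a victim from the tail of self.running; on a pure KV-consumer (decode) instance the victim is reported back as "recomputed" instead of being preempted locally, while the next hunk shows the local-preemption fallback. A simplified, self-contained sketch of that victim policy (illustration only; _Req and _State are stand-ins, not vLLM types):

from dataclasses import dataclass, field

@dataclass
class _Req:
    request_id: str
    num_computed_tokens: int = 0

@dataclass
class _State:
    running: list          # lowest-priority request last
    waiting: list = field(default_factory=list)
    recomputed: list = field(default_factory=list)

def on_allocation_failure(state: _State, request: _Req, kv_consumer: bool) -> bool:
    """Evict one victim; return True if `request` may retry allocation."""
    victim = state.running.pop()
    if kv_consumer:
        # Decode instance: hand the victim back to the client as "recomputed";
        # a prefill instance will rebuild its KV cache instead of this node.
        state.recomputed.append(victim.request_id)
    else:
        # Co-located / producer instance: ordinary preemption back to WAITING.
        victim.num_computed_tokens = 0
        state.waiting.insert(0, victim)
    return victim is not request

s = _State(running=[_Req("a"), _Req("b"), _Req("c")])
assert on_allocation_failure(s, s.running[0], kv_consumer=True)   # "c" is evicted
assert s.recomputed == ["c"] and [r.request_id for r in s.running] == ["a", "b"]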
@@ -294,6 +326,26 @@ def schedule(self) -> RecomputeSchedulerOutput: self.running.remove(preempted_req) if preempted_req in scheduled_running_reqs: scheduled_running_reqs.remove(preempted_req) + token_budget += num_scheduled_tokens[ + preempted_req.request_id + ] + req_to_new_blocks.pop(preempted_req.request_id) + num_scheduled_tokens.pop(preempted_req.request_id) + scheduled_spec_decode_tokens.pop( + preempted_req.request_id, None + ) + preempted_encoder_inputs = scheduled_encoder_inputs.pop( + preempted_req.request_id, None + ) + if preempted_encoder_inputs: + # Restore encoder compute budget if the preempted + # request had encoder inputs scheduled in this step. + num_tokens_to_restore = sum( + preempted_req.get_num_encoder_tokens(i) + for i in preempted_encoder_inputs + ) + encoder_compute_budget += num_tokens_to_restore + req_index -= 1 else: preempted_req = self.running.pop() @@ -301,24 +353,21 @@ def schedule(self) -> RecomputeSchedulerOutput: self.encoder_cache_manager.free(preempted_req) preempted_req.status = RequestStatus.PREEMPTED preempted_req.num_computed_tokens = 0 + preempted_req.num_preemptions += 1 if self.log_stats: preempted_req.record_event( - EngineCoreEventType.PREEMPTED, - scheduled_timestamp) + EngineCoreEventType.PREEMPTED, scheduled_timestamp + ) self.waiting.prepend_request(preempted_req) preempted_reqs.append(preempted_req) if preempted_req == request: - # No more request to preempt. - can_schedule = False + # No more request to preempt. Cannot schedule this request. break - else: - # The request can be scheduled. - can_schedule = True - break - if not can_schedule: + + if new_blocks is None: + # Cannot schedule this request. break - assert new_blocks is not None # Schedule the request. scheduled_running_reqs.append(request) @@ -329,30 +378,45 @@ def schedule(self) -> RecomputeSchedulerOutput: # Speculative decode related. if request.spec_token_ids: - num_scheduled_spec_tokens = (num_new_tokens + - request.num_computed_tokens - - request.num_tokens) + num_scheduled_spec_tokens = ( + num_new_tokens + + request.num_computed_tokens + - request.num_tokens + - request.num_output_placeholders + ) if num_scheduled_spec_tokens > 0: # Trim spec_token_ids list to num_scheduled_spec_tokens. del request.spec_token_ids[num_scheduled_spec_tokens:] scheduled_spec_decode_tokens[request.request_id] = ( - request.spec_token_ids) + request.spec_token_ids + ) + # New spec tokens will be set in `update_draft_token_ids` before the + # next step when applicable. + request.spec_token_ids = [] # Encoder-related. if encoder_inputs_to_schedule: scheduled_encoder_inputs[request.request_id] = ( - encoder_inputs_to_schedule) + encoder_inputs_to_schedule + ) # Allocate the encoder cache. 
for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) encoder_compute_budget = new_encoder_compute_budget + if external_load_encoder_input: + for i in external_load_encoder_input: + self.encoder_cache_manager.allocate(request, i) + if self.ec_connector is not None: + self.ec_connector.update_state_after_alloc(request, i) # Record the LoRAs in scheduled_running_reqs scheduled_loras: set[int] = set() if self.lora_config: scheduled_loras = set( - req.lora_request.lora_int_id for req in scheduled_running_reqs - if req.lora_request and req.lora_request.lora_int_id > 0) + req.lora_request.lora_int_id + for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0 + ) assert len(scheduled_loras) <= self.lora_config.max_loras # Use a temporary RequestQueue to collect requests that need to be @@ -375,7 +439,8 @@ def schedule(self) -> RecomputeSchedulerOutput: else: logger.debug( "%s is still in WAITING_FOR_REMOTE_KVS state.", - request.request_id) + request.request_id, + ) self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) continue @@ -393,9 +458,14 @@ def schedule(self) -> RecomputeSchedulerOutput: # Check that adding the request still respects the max_loras # constraint. - if (self.lora_config and request.lora_request and - (len(scheduled_loras) == self.lora_config.max_loras and - request.lora_request.lora_int_id not in scheduled_loras)): + if ( + self.lora_config + and request.lora_request + and ( + len(scheduled_loras) == self.lora_config.max_loras + and request.lora_request.lora_int_id not in scheduled_loras + ) + ): # Scheduling would exceed max_loras, skip. self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) @@ -407,17 +477,19 @@ def schedule(self) -> RecomputeSchedulerOutput: # Get already-cached tokens. if request.num_computed_tokens == 0: # Get locally-cached tokens. - new_computed_blocks, num_new_local_computed_tokens = \ - self.kv_cache_manager.get_computed_blocks( - request) + new_computed_blocks, num_new_local_computed_tokens = ( + self.kv_cache_manager.get_computed_blocks(request) + ) # Get externally-cached tokens if using a KVConnector. if self.connector is not None: - num_external_computed_tokens, load_kv_async = ( + ext_tokens, load_kv_async = ( self.connector.get_num_new_matched_tokens( - request, num_new_local_computed_tokens)) + request, num_new_local_computed_tokens + ) + ) - if num_external_computed_tokens is None: + if ext_tokens is None: # The request cannot be scheduled because # the KVConnector couldn't determine # the number of matched tokens. @@ -425,39 +497,44 @@ def schedule(self) -> RecomputeSchedulerOutput: skipped_waiting_requests.prepend_request(request) continue + request.num_external_computed_tokens = ext_tokens + num_external_computed_tokens = ext_tokens + # Total computed tokens (local + external). - num_computed_tokens = (num_new_local_computed_tokens + - num_external_computed_tokens) - # KVTransfer: WAITING reqs have num_computed_tokens > 0 - # after async KV recvs are completed. + num_computed_tokens = ( + num_new_local_computed_tokens + num_external_computed_tokens + ) else: - new_computed_blocks = ( - self.kv_cache_manager.create_empty_block_list()) + # KVTransfer: WAITING reqs have num_computed_tokens > 0 + # after async KV recvs are completed. 
+ new_computed_blocks = self.kv_cache_manager.empty_kv_cache_blocks num_new_local_computed_tokens = 0 num_computed_tokens = request.num_computed_tokens encoder_inputs_to_schedule = None + external_load_encoder_input = [] new_encoder_compute_budget = encoder_compute_budget - # KVTransfer: loading remote KV, do not allocate for new work. if load_kv_async: + # KVTransfer: loading remote KV, do not allocate for new work. assert num_external_computed_tokens > 0 num_new_tokens = 0 - # Number of tokens to be scheduled. else: + # Number of tokens to be scheduled. # We use `request.num_tokens` instead of # `request.num_prompt_tokens` to consider the resumed # requests, which have output tokens. num_new_tokens = request.num_tokens - num_computed_tokens - if (0 < self.scheduler_config.long_prefill_token_threshold - < num_new_tokens): - num_new_tokens = ( - self.scheduler_config.long_prefill_token_threshold) + threshold = self.scheduler_config.long_prefill_token_threshold + if 0 < threshold < num_new_tokens: + num_new_tokens = threshold # chunked prefill has to be enabled explicitly to allow # pooling requests to be chunked - if not self.scheduler_config.chunked_prefill_enabled and \ - num_new_tokens > token_budget: + if ( + not self.scheduler_config.enable_chunked_prefill + and num_new_tokens > token_budget + ): self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) continue @@ -467,11 +544,17 @@ def schedule(self) -> RecomputeSchedulerOutput: # Schedule encoder inputs. if request.has_encoder_inputs: - (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_compute_budget - ) = self._try_schedule_encoder_inputs( - request, num_computed_tokens, num_new_tokens, - encoder_compute_budget) + ( + encoder_inputs_to_schedule, + num_new_tokens, + new_encoder_compute_budget, + external_load_encoder_input, + ) = self._try_schedule_encoder_inputs( + request, + num_computed_tokens, + num_new_tokens, + encoder_compute_budget, + ) if num_new_tokens == 0: # The request cannot be scheduled. break @@ -481,9 +564,9 @@ def schedule(self) -> RecomputeSchedulerOutput: # extra block gets allocated which # creates a mismatch between the number # of local and remote blocks. - effective_lookahead_tokens = (0 if request.num_computed_tokens - == 0 else - self.num_lookahead_tokens) + effective_lookahead_tokens = ( + 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens + ) # Determine if we need to allocate cross-attention blocks. if self.is_encoder_decoder and request.has_encoder_inputs: @@ -491,8 +574,9 @@ def schedule(self) -> RecomputeSchedulerOutput: # always padded to the maximum length. If we support other # encoder-decoder models, this will need to be updated if we # want to only allocate what is needed. 
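Prefix-cache accounting for a WAITING request above combines locally cached blocks with tokens matched through the KV connector before sizing the remaining prefill. A compact sketch of that arithmetic, including the bail-out when chunked prefill is disabled (illustration only; plain ints stand in for the manager/connector results):

def remaining_prefill(num_prompt_tokens: int, local_hit_tokens: int,
                      external_hit_tokens: int, token_budget: int,
                      chunked_prefill_enabled: bool) -> int | None:
    # Total computed tokens (local + external), as in the scheduler above.
    num_computed = local_hit_tokens + external_hit_tokens
    num_new = num_prompt_tokens - num_computed
    if not chunked_prefill_enabled and num_new > token_budget:
        return None                      # cannot schedule this request this step
    return min(num_new, token_budget)

assert remaining_prefill(4096, 1024, 2048, 512, chunked_prefill_enabled=True) == 512
assert remaining_prefill(4096, 0, 0, 512, chunked_prefill_enabled=False) is None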
- num_encoder_tokens = \ + num_encoder_tokens = ( self.scheduler_config.max_num_encoder_input_tokens + ) else: num_encoder_tokens = 0 @@ -531,23 +615,26 @@ def schedule(self) -> RecomputeSchedulerOutput: request.status = RequestStatus.WAITING_FOR_REMOTE_KVS continue + self._update_connector_prefix_cache_stats(request) + req_index += 1 self.running.append(request) if self.log_stats: - request.record_event(EngineCoreEventType.SCHEDULED, - scheduled_timestamp) + request.record_event( + EngineCoreEventType.SCHEDULED, scheduled_timestamp + ) if request.status == RequestStatus.WAITING: scheduled_new_reqs.append(request) elif request.status == RequestStatus.PREEMPTED: scheduled_resumed_reqs.append(request) else: - raise RuntimeError( - f"Invalid request status: {request.status}") + raise RuntimeError(f"Invalid request status: {request.status}") if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) req_to_new_blocks[request.request_id] = ( - self.kv_cache_manager.get_blocks(request.request_id)) + self.kv_cache_manager.get_blocks(request.request_id) + ) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING @@ -558,12 +645,18 @@ def schedule(self) -> RecomputeSchedulerOutput: # Encoder-related. if encoder_inputs_to_schedule: scheduled_encoder_inputs[request.request_id] = ( - encoder_inputs_to_schedule) + encoder_inputs_to_schedule + ) # Allocate the encoder cache. for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) encoder_compute_budget = new_encoder_compute_budget - + # Allocate for external load encoder cache + if external_load_encoder_input: + for i in external_load_encoder_input: + self.encoder_cache_manager.allocate(request, i) + if self.ec_connector is not None: + self.ec_connector.update_state_after_alloc(request, i) # Put back any skipped requests at the head of the waiting queue if skipped_waiting_requests: self.waiting.prepend_requests(skipped_waiting_requests) @@ -571,42 +664,61 @@ def schedule(self) -> RecomputeSchedulerOutput: # Check if the scheduling constraints are satisfied. total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + assert token_budget >= 0 assert len(self.running) <= self.max_num_running_reqs # Since some requests in the RUNNING queue may not be scheduled in # this step, the total number of scheduled requests can be smaller than # len(self.running). - assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + - len(scheduled_running_reqs) <= len(self.running)) + assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len( + scheduled_running_reqs + ) <= len(self.running) # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. - num_common_prefix_blocks = [0] * len( - self.kv_cache_config.kv_cache_groups) - if self.running: - any_request = self.running[0] - num_common_prefix_blocks = ( - self.kv_cache_manager.get_num_common_prefix_blocks( - any_request.request_id)) + num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) + with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"): + if self.running: + any_request = self.running[0] + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request.request_id + ) + ) # Construct the scheduler output. 
- new_reqs_data = [ - NewRequestData.from_request( - req, req_to_new_blocks[req.request_id].get_block_ids()) - for req in scheduled_new_reqs - ] - cached_reqs_data = self._make_cached_request_data( - scheduled_running_reqs, - scheduled_resumed_reqs, - num_scheduled_tokens, - scheduled_spec_decode_tokens, - req_to_new_blocks, - ) - scheduled_requests = (scheduled_new_reqs + scheduled_running_reqs + - scheduled_resumed_reqs) - structured_output_request_ids, grammar_bitmask = ( - self.get_grammar_bitmask(scheduled_requests, - scheduled_spec_decode_tokens)) + if self.use_v2_model_runner: + scheduled_new_reqs = scheduled_new_reqs + scheduled_resumed_reqs + scheduled_resumed_reqs = [] + new_reqs_data = [ + NewRequestData.from_request( + req, + req_to_new_blocks[req.request_id].get_block_ids(), + req._all_token_ids, + ) + for req in scheduled_new_reqs + ] + else: + new_reqs_data = [ + NewRequestData.from_request( + req, req_to_new_blocks[req.request_id].get_block_ids() + ) + for req in scheduled_new_reqs + ] + + with record_function_or_nullcontext("schedule: make_cached_request_data"): + cached_reqs_data = self._make_cached_request_data( + scheduled_running_reqs, + scheduled_resumed_reqs, + num_scheduled_tokens, + scheduled_spec_decode_tokens, + req_to_new_blocks, + ) + + # Record the request ids that were scheduled in this step. + self.prev_step_scheduled_req_ids.clear() + self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys()) + scheduler_output = RecomputeSchedulerOutput( scheduled_new_reqs=new_reqs_data, scheduled_cached_reqs=cached_reqs_data, @@ -615,15 +727,13 @@ def schedule(self) -> RecomputeSchedulerOutput: scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, scheduled_encoder_inputs=scheduled_encoder_inputs, num_common_prefix_blocks=num_common_prefix_blocks, + preempted_req_ids={req.request_id for req in preempted_reqs}, # finished_req_ids is an existing state in the scheduler, # instead of being newly scheduled in this step. # It contains the request IDs that are finished in between # the previous and the current steps. finished_req_ids=self.finished_req_ids, - free_encoder_mm_hashes=self.encoder_cache_manager. - get_freed_mm_hashes(), - structured_output_request_ids=structured_output_request_ids, - grammar_bitmask=grammar_bitmask, + free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), recomputed_reqs=recomputed_reqs, ) @@ -632,32 +742,25 @@ def schedule(self) -> RecomputeSchedulerOutput: # 2. Wrap up all the KV cache load / save ops into an opaque object # 3. 
Clear the internal states of the connector if self.connector is not None: - meta = self.connector.build_connector_meta(scheduler_output) + meta: KVConnectorMetadata = self.connector.build_connector_meta( + scheduler_output + ) scheduler_output.kv_connector_metadata = meta - # collect KV cache events from KV cache manager - events = self.kv_cache_manager.take_events() - - # collect KV cache events from connector - if self.connector is not None: - connector_events = self.connector.take_events() - if connector_events: - if events is None: - events = list(connector_events) - else: - events.extend(connector_events) - - # publish collected KV cache events - if events: - batch = KVEventBatch(ts=time.time(), events=events) - self.kv_event_publisher.publish(batch) + # Build the connector meta for ECConnector + if self.ec_connector is not None: + ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta( + scheduler_output + ) + scheduler_output.ec_connector_metadata = ec_meta - self._update_after_schedule(scheduler_output) + with record_function_or_nullcontext("schedule: update_after_schedule"): + self._update_after_schedule(scheduler_output) return scheduler_output def _update_after_schedule( self, - scheduler_output: RecomputeSchedulerOutput, + scheduler_output: SchedulerOutput, ) -> None: # Advance the number of computed tokens for the request AFTER # the request is scheduled. @@ -696,43 +799,51 @@ def _make_cached_request_data( ) -> CachedRequestData: req_ids: list[str] = [] new_token_ids: list[list[int]] = [] - new_block_ids: list[Optional[tuple[list[int], ...]]] = [] + new_block_ids: list[tuple[list[int], ...] | None] = [] + all_token_ids: dict[str, list[int]] = {} num_computed_tokens: list[int] = [] + num_output_tokens: list[int] = [] + resumed_req_ids = set() - use_connector = self.connector is not None - for req in itertools.chain(running_reqs, resumed_reqs): + num_running_reqs = len(running_reqs) + for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)): req_id = req.request_id req_ids.append(req_id) - num_tokens = (num_scheduled_tokens[req_id] - - len(spec_decode_tokens.get(req_id, ()))) + num_tokens = num_scheduled_tokens[req_id] - len( + spec_decode_tokens.get(req_id, ()) + ) if self.use_pp: # When using PP, the scheduler sends the sampled tokens back, # because there's no direct communication between the first- # stage worker and the last-stage worker. Otherwise, we don't # need to send the sampled tokens back because the model runner # will cache them. - token_ids = req.all_token_ids[req.num_computed_tokens:req. - num_computed_tokens + num_tokens] + token_ids = req.all_token_ids[ + req.num_computed_tokens : req.num_computed_tokens + num_tokens + ] new_token_ids.append(token_ids) - elif use_connector: - # When using a KVConnector, we add a placeholder to avoid index - # out of bounds errors. TODO: Remove this once the KVConnector - # is updated to handle token IDs properly. 
- new_token_ids.append([]) + scheduled_in_prev_step = req_id in self.prev_step_scheduled_req_ids + if idx >= num_running_reqs: + assert not scheduled_in_prev_step + resumed_req_ids.add(req_id) + if not scheduled_in_prev_step: + all_token_ids[req_id] = req.all_token_ids.copy() new_block_ids.append( - req_to_new_blocks[req_id].get_block_ids(allow_none=True)) + req_to_new_blocks[req_id].get_block_ids(allow_none=True) + ) num_computed_tokens.append(req.num_computed_tokens) - # Because resumed_reqs is usually empty, it is more efficient to do - # in-place appending so that we don't need to allocate a new list. - resumed_from_preemption = [False] * len(running_reqs) - resumed_from_preemption += [True] * len(resumed_reqs) + num_output_tokens.append( + req.num_output_tokens + req.num_output_placeholders + ) return CachedRequestData( req_ids=req_ids, - resumed_from_preemption=resumed_from_preemption, + resumed_req_ids=resumed_req_ids, new_token_ids=new_token_ids, + all_token_ids=all_token_ids, new_block_ids=new_block_ids, num_computed_tokens=num_computed_tokens, + num_output_tokens=num_output_tokens, ) def _try_schedule_encoder_inputs( @@ -741,7 +852,7 @@ def _try_schedule_encoder_inputs( num_computed_tokens: int, num_new_tokens: int, encoder_compute_budget: int, - ) -> tuple[list[int], int, int]: + ) -> tuple[list[int], int, int, list[int]]: """ Determine which encoder inputs need to be scheduled in the current step, and update `num_new_tokens` and encoder token budget accordingly. @@ -751,6 +862,7 @@ def _try_schedule_encoder_inputs( in this step, i.e., [num_computed_tokens, num_computed_tokens + num_new_tokens). - It is not already computed and stored in the encoder cache. + - It is not exist on remote encoder cache (via ECConnector) - There is sufficient encoder token budget to process it. - The encoder cache has space to store it. @@ -762,12 +874,16 @@ def _try_schedule_encoder_inputs( blocks and externally cached blocks (via KVConnector). """ if num_new_tokens == 0 or not request.has_encoder_inputs: - return [], num_new_tokens, encoder_compute_budget + return [], num_new_tokens, encoder_compute_budget, [] encoder_inputs_to_schedule: list[int] = [] mm_features = request.mm_features assert mm_features is not None assert len(mm_features) > 0 + external_load_encoder_input = [] + # Check remote cache first + if self.ec_connector is not None: + remote_cache_has_item = self.ec_connector.has_caches(request) # NOTE: since scheduler operates on the request level (possibly with # multiple encoder inputs per request), we need to create temporary # trackers for accounting at the encoder input level. @@ -787,7 +903,8 @@ def _try_schedule_encoder_inputs( if self.is_encoder_decoder and num_computed_tokens > 0: assert start_pos == 0, ( "Encoder input should be processed at the beginning of " - "the sequence when encoder-decoder models are used.") + "the sequence when encoder-decoder models are used." + ) # Encoder input has already been computed # The calculation here is a bit different. We don't turn encoder # output into tokens that get processed by the decoder and @@ -811,8 +928,7 @@ def _try_schedule_encoder_inputs( # current step. continue - if self.encoder_cache_manager.check_and_update_cache( - request, i): + if self.encoder_cache_manager.check_and_update_cache(request, i): # The encoder input is already computed and cached from a # previous step. 
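_make_cached_request_data above keys its "resend everything" decision on prev_step_scheduled_req_ids: any request that did not run in the immediately preceding step (including every resumed one) gets its full token list shipped again, since the worker's cached copy may be stale. A minimal sketch of that rule (illustration only; the rationale in the comments is the editor's reading of the code above):

def needs_full_token_resend(req_id: str, is_resumed: bool,
                            prev_step_scheduled: set[str]) -> bool:
    scheduled_in_prev_step = req_id in prev_step_scheduled
    if is_resumed:
        # Resumed requests were preempted earlier, so they cannot have been
        # part of the previous step's batch.
        assert not scheduled_in_prev_step
    # Only requests the workers just ran have an up-to-date cached token list.
    return not scheduled_in_prev_step

prev = {"req-0"}
assert not needs_full_token_resend("req-0", is_resumed=False, prev_step_scheduled=prev)
assert needs_full_token_resend("req-1", is_resumed=True, prev_step_scheduled=prev)
assert needs_full_token_resend("req-2", is_resumed=False, prev_step_scheduled=prev)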
continue @@ -820,16 +936,18 @@ def _try_schedule_encoder_inputs( # If no encoder input chunking is allowed, we do not want to # partially schedule a multimodal item. If the scheduled range would # only cover part of the mm input, roll back to before the mm item. - if (self.scheduler_config.disable_chunked_mm_input - and num_computed_tokens < start_pos - and (num_computed_tokens + num_new_tokens) - < (start_pos + num_encoder_tokens)): + if ( + self.scheduler_config.disable_chunked_mm_input + and num_computed_tokens < start_pos + and (num_computed_tokens + num_new_tokens) + < (start_pos + num_encoder_tokens) + ): num_new_tokens = start_pos - num_computed_tokens break if not self.encoder_cache_manager.can_allocate( - request, i, encoder_compute_budget, - num_tokens_to_schedule): + request, i, encoder_compute_budget, num_tokens_to_schedule + ): # The encoder cache is full or the encoder budget is exhausted. # NOTE(woosuk): We assume that the encoder input tokens should # be processed altogether, as the encoder usually uses @@ -846,6 +964,12 @@ def _try_schedule_encoder_inputs( num_new_tokens = 0 break + if self.ec_connector is not None and remote_cache_has_item[i]: + mm_hashes_to_schedule.add(request.mm_features[i].identifier) + external_load_encoder_input.append(i) + num_tokens_to_schedule += num_encoder_tokens + continue + num_tokens_to_schedule += num_encoder_tokens encoder_compute_budget -= num_encoder_tokens mm_hashes_to_schedule.add(request.mm_features[i].identifier) @@ -855,41 +979,37 @@ def _try_schedule_encoder_inputs( encoder_inputs_to_schedule, num_new_tokens, encoder_compute_budget, + external_load_encoder_input, ) def get_grammar_bitmask( self, - requests: list[Request], - scheduled_spec_decode_tokens: dict[str, list[int]], - ): - # NOTE: structured_output_request_ids maps - # a request's (request that uses structured output) - # request_id to its index in the batch. - # This will help us determine to slice the grammar bitmask - # and only applies valid mask for requests that - # uses structured decoding. - structured_output_request_ids: dict[str, int] = {} - for i, req in enumerate(requests): - if req.use_structured_output: - # PERF: in case of chunked prefill, - # request might not include any new tokens. - # Therefore, we might introduce some additional - # cycle to fill in the bitmask, which could be a big no-op. - structured_output_request_ids[req.request_id] = i - + scheduler_output: SchedulerOutput, + ) -> GrammarOutput | None: + # Collect list of scheduled request ids that use structured output. + # The corresponding rows of the bitmask will be in this order. + # PERF: in case of chunked prefill, + # request might not include any new tokens. + # Therefore, we might introduce some additional + # cycle to fill in the bitmask, which could be a big no-op. 
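In _try_schedule_encoder_inputs above, an input found in the remote encoder cache (via the ECConnector) still reserves encoder-cache space but spends none of the local encoder compute budget, unlike an input that must be encoded on this node. A tiny sketch of that bookkeeping difference (illustration only):

def encoder_input_cost(remote_hit: bool, num_encoder_tokens: int) -> tuple[int, int]:
    """Returns (cache_tokens_reserved, compute_budget_spent) for one mm input."""
    if remote_hit:
        # Fetched through the EC connector: cache space yes, local encode no.
        return num_encoder_tokens, 0
    return num_encoder_tokens, num_encoder_tokens

assert encoder_input_cost(True, 576) == (576, 0)    # e.g. a remotely cached image embedding
assert encoder_input_cost(False, 576) == (576, 576)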
+ structured_output_request_ids = [ + req_id + for req_id in scheduler_output.num_scheduled_tokens + if (req := self.requests.get(req_id)) and req.use_structured_output + ] if not structured_output_request_ids: - bitmask = None - else: - bitmask = self.structured_output_manager.grammar_bitmask( - self.requests, - structured_output_request_ids, - scheduled_spec_decode_tokens, - ) - return structured_output_request_ids, bitmask + return None + + bitmask = self.structured_output_manager.grammar_bitmask( + self.requests, + structured_output_request_ids, + scheduler_output.scheduled_spec_decode_tokens, + ) + return GrammarOutput(structured_output_request_ids, bitmask) def update_from_output( self, - scheduler_output: RecomputeSchedulerOutput, + scheduler_output: SchedulerOutput, model_runner_output: ModelRunnerOutput, ) -> dict[int, EngineCoreOutputs]: sampled_token_ids = model_runner_output.sampled_token_ids @@ -901,18 +1021,33 @@ def update_from_output( kv_connector_output = model_runner_output.kv_connector_output outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) - spec_decoding_stats: Optional[SpecDecodingStats] = None - kv_connector_stats = (kv_connector_output.kv_connector_stats - if kv_connector_output else None) - # return recomputed requests as EngineCoreOutput - for req_info in scheduler_output.recomputed_reqs: - outputs[req_info.client_index].append( - EngineCoreOutput( - request_id=req_info.request_id, - finish_reason=FinishReason.STOP, - new_token_ids=[req_info.output_token_ids[-1]], - stop_reason="recomputed", - )) + spec_decoding_stats: SpecDecodingStats | None = None + kv_connector_stats: KVConnectorStats | None = ( + kv_connector_output.kv_connector_stats if kv_connector_output else None + ) + if kv_connector_stats and self.connector: + kv_stats = self.connector.get_kv_connector_stats() + if kv_stats: + kv_connector_stats = kv_connector_stats.aggregate(kv_stats) + + failed_kv_load_req_ids = None + if kv_connector_output and kv_connector_output.invalid_block_ids: + # These blocks contain externally computed tokens that failed to + # load. Identify affected requests and adjust their computed token + # count to trigger recomputation of the invalid blocks. + failed_kv_load_req_ids = self._handle_invalid_blocks( + kv_connector_output.invalid_block_ids + ) + if isinstance(scheduler_output, RecomputeSchedulerOutput): + # return recomputed requests as EngineCoreOutput + for req_info in scheduler_output.recomputed_reqs: + outputs[req_info.client_index].append( + EngineCoreOutput( + request_id=req_info.request_id, + finish_reason=FinishReason.STOP, + new_token_ids=[req_info.output_token_ids[-1]], + stop_reason="recomputed", + )) # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, # the below loop can be a performance bottleneck. We should do our best # to avoid expensive operations inside the loop. @@ -920,6 +1055,9 @@ def update_from_output( stopped_preempted_reqs: set[Request] = set() for req_id, num_tokens_scheduled in num_scheduled_tokens.items(): assert num_tokens_scheduled > 0 + if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids: + # Skip requests that were recovered from KV load failure + continue request = self.requests.get(req_id) if request is None: # The request is already finished. 
This can happen if the @@ -928,11 +1066,13 @@ def update_from_output( continue req_index = model_runner_output.req_id_to_index[req_id] - generated_token_ids = sampled_token_ids[ - req_index] if sampled_token_ids else [] + generated_token_ids = ( + sampled_token_ids[req_index] if sampled_token_ids else [] + ) scheduled_spec_token_ids = ( - scheduler_output.scheduled_spec_decode_tokens.get(req_id)) + scheduler_output.scheduled_spec_decode_tokens.get(req_id) + ) if scheduled_spec_token_ids: num_draft_tokens = len(scheduled_spec_token_ids) num_accepted = len(generated_token_ids) - 1 @@ -942,11 +1082,17 @@ def update_from_output( # tokens and rejections. If some tokens are rejected, # num_computed_tokens is decreased by the number of rejected # tokens. - request.num_computed_tokens -= num_rejected + if request.num_computed_tokens > 0: + request.num_computed_tokens -= num_rejected + # If async scheduling, num_output_placeholders also includes + # the scheduled spec tokens count and so is similarly adjusted. + if request.num_output_placeholders > 0: + request.num_output_placeholders -= num_rejected spec_decoding_stats = self.make_spec_decoding_stats( spec_decoding_stats, num_draft_tokens=num_draft_tokens, - num_accepted_tokens=num_accepted) + num_accepted_tokens=num_accepted, + ) stopped = False new_logprobs = None @@ -957,14 +1103,14 @@ def update_from_output( # Check for stop and update request status. if new_token_ids: new_token_ids, stopped = self._update_request_with_output( - request, new_token_ids) + request, new_token_ids + ) # Stop checking for pooler models. pooler_output = None if pooler_outputs: pooler_output = pooler_outputs[req_index] - stopped = check_stop(request, self.max_model_len, - pooler_output) + stopped = check_stop(request, self.max_model_len, pooler_output) if stopped: kv_transfer_params = self._free_request(request) @@ -974,28 +1120,27 @@ def update_from_output( stopped_preempted_reqs.add(request) # Extract sample logprobs if needed. - if request.sampling_params is not None \ - and request.sampling_params.logprobs is not None and logprobs: + if ( + request.sampling_params is not None + and request.sampling_params.logprobs is not None + and logprobs + ): # NOTE: once we support N tokens per step (spec decode), # the outer lists can be of length > 1. new_logprobs = logprobs.slice(req_index, req_index + 1) - if new_token_ids and self.structured_output_manager.should_advance( - request): - # NOTE: structured_output_request - # should not be None if use_structured_output, we have - # checked above, so safe to ignore type warning - request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] - req_id, new_token_ids) + if new_token_ids and self.structured_output_manager.should_advance(request): + struct_output_request = request.structured_output_request + assert struct_output_request is not None + assert struct_output_request.grammar is not None + struct_output_request.grammar.accept_tokens(req_id, new_token_ids) if num_nans_in_logits is not None and req_id in num_nans_in_logits: request.num_nans_in_logits = num_nans_in_logits[req_id] # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) - if new_token_ids or pooler_output is not None \ - or kv_transfer_params: - + if new_token_ids or pooler_output is not None or kv_transfer_params: # Add EngineCoreOutput for this Request. 
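The speculative-decoding branch above converts the sampler result into an acceptance count and rolls num_computed_tokens (and, under async scheduling, num_output_placeholders) back by the number of rejected draft tokens. A small worked sketch of that bookkeeping (illustration only):

def apply_spec_decode_result(num_draft_tokens: int, num_generated_tokens: int,
                             num_computed_tokens: int,
                             num_output_placeholders: int) -> tuple[int, int, int]:
    # The sampler returns one "bonus" token plus every accepted draft token.
    num_accepted = num_generated_tokens - 1
    num_rejected = num_draft_tokens - num_accepted
    # Rejected positions were optimistically counted as computed when scheduled.
    if num_computed_tokens > 0:
        num_computed_tokens -= num_rejected
    if num_output_placeholders > 0:
        num_output_placeholders -= num_rejected
    return num_accepted, num_computed_tokens, num_output_placeholders

# 4 drafted tokens, 3 tokens came back from the sampler: 2 accepted, 2 rejected.
assert apply_spec_decode_result(4, 3, 100, 0) == (2, 98, 0)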
outputs[request.client_index].append( EngineCoreOutput( @@ -1010,7 +1155,9 @@ def update_from_output( kv_transfer_params=kv_transfer_params, trace_headers=request.trace_headers, num_cached_tokens=request.num_cached_tokens, - )) + num_nans_in_logits=request.num_nans_in_logits, + ) + ) else: # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors @@ -1023,9 +1170,25 @@ def update_from_output( self.waiting.remove_requests(stopped_preempted_reqs) # KV Connector: update state for finished KV Transfers. - if model_runner_output.kv_connector_output: - self._update_from_kv_xfer_finished( - model_runner_output.kv_connector_output) + if kv_connector_output: + self._update_from_kv_xfer_finished(kv_connector_output) + + # collect KV cache events from KV cache manager + events = self.kv_cache_manager.take_events() + + # collect KV cache events from connector + if self.connector is not None: + connector_events = self.connector.take_events() + if connector_events: + if events is None: + events = list(connector_events) + else: + events.extend(connector_events) + + # publish collected KV cache events + if events: + batch = KVEventBatch(ts=time.time(), events=events) + self.kv_event_publisher.publish(batch) # Create EngineCoreOutputs for all clients that have requests with # outputs in this step. @@ -1044,11 +1207,13 @@ def update_from_output( eco.finished_requests = finished_set else: engine_core_outputs[client_index] = EngineCoreOutputs( - finished_requests=finished_set) + finished_requests=finished_set + ) finished_req_ids.clear() - if (stats := self.make_stats(spec_decoding_stats, - kv_connector_stats)) is not None: + if ( + stats := self.make_stats(spec_decoding_stats, kv_connector_stats) + ) is not None: # Return stats to only one of the front-ends. if (eco := next(iter(engine_core_outputs.values()), None)) is None: # We must return the stats even if there are no request @@ -1079,8 +1244,9 @@ def _update_request_with_output( return new_token_ids, stopped def _free_encoder_inputs(self, request: Request) -> None: - cached_encoder_input_ids = ( - self.encoder_cache_manager.get_cached_input_ids(request)) + cached_encoder_input_ids = self.encoder_cache_manager.get_cached_input_ids( + request + ) # OPTIMIZATION: Avoid list(set) if the set is empty. if not cached_encoder_input_ids: return @@ -1095,21 +1261,19 @@ def _free_encoder_inputs(self, request: Request) -> None: # With Whisper, as soon as we've generated a single token, # we know we're done with the encoder input. Cross Attention # KVs have been calculated and cached already. - self.encoder_cache_manager.free_encoder_input( - request, input_id) + self.encoder_cache_manager.free_encoder_input(request, input_id) elif start_pos + num_tokens <= request.num_computed_tokens: # The encoder output is already processed and stored # in the decoder's KV cache. - self.encoder_cache_manager.free_encoder_input( - request, input_id) + self.encoder_cache_manager.free_encoder_input(request, input_id) def update_draft_token_ids( self, draft_token_ids: DraftTokenIds, ) -> None: for req_id, spec_token_ids in zip( - draft_token_ids.req_ids, - draft_token_ids.draft_token_ids, + draft_token_ids.req_ids, + draft_token_ids.draft_token_ids, ): request = self.requests.get(req_id) if request is None or request.is_finished(): @@ -1117,13 +1281,11 @@ def update_draft_token_ids( continue # Add newly generated spec token ids to the request. - if not spec_token_ids: - # NOTE(woosuk): request.spec_token_ids should be updated. 
- request.spec_token_ids.clear() - elif self.structured_output_manager.should_advance(request): + if self.structured_output_manager.should_advance(request): metadata = request.structured_output_request request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] - spec_token_ids) + spec_token_ids + ) else: request.spec_token_ids = spec_token_ids @@ -1139,7 +1301,7 @@ def add_request(self, request: Request) -> None: def finish_requests( self, - request_ids: Union[str, Iterable[str]], + request_ids: str | Iterable[str], finished_status: RequestStatus, ) -> None: """Handles the finish signal from outside the scheduler. @@ -1149,7 +1311,7 @@ def finish_requests( """ assert RequestStatus.is_finished(finished_status) if isinstance(request_ids, str): - request_ids = (request_ids, ) + request_ids = (request_ids,) else: request_ids = set(request_ids) @@ -1160,7 +1322,7 @@ def finish_requests( # First pass: collect requests to remove from queues for req_id in request_ids: request = self.requests.get(req_id) - if request is None: + if request is None or request.is_finished(): # Invalid request ID. continue @@ -1181,7 +1343,7 @@ def finish_requests( request.status = finished_status self._free_request(request) - def _free_request(self, request: Request) -> Optional[dict[str, Any]]: + def _free_request(self, request: Request) -> dict[str, Any] | None: assert request.is_finished() delay_free_blocks, kv_xfer_params = self._connector_finished(request) @@ -1212,36 +1374,37 @@ def reset_prefix_cache(self) -> bool: def make_stats( self, - spec_decoding_stats: Optional[SpecDecodingStats] = None, - kv_connector_stats: Optional[KVConnectorStats] = None, - ) -> Optional[SchedulerStats]: + spec_decoding_stats: SpecDecodingStats | None = None, + kv_connector_stats: KVConnectorStats | None = None, + ) -> SchedulerStats | None: if not self.log_stats: return None prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats() assert prefix_cache_stats is not None - return SchedulerStats(num_running_reqs=len(self.running), - num_waiting_reqs=len(self.waiting), - kv_cache_usage=self.kv_cache_manager.usage, - prefix_cache_stats=prefix_cache_stats, - spec_decoding_stats=spec_decoding_stats, - num_corrupted_reqs=sum(req.is_output_corrupted - for req in self.running), - kv_connector_stats=kv_connector_stats.data - if kv_connector_stats else None) + connector_prefix_cache_stats = self._make_connector_prefix_cache_stats() + return SchedulerStats( + num_running_reqs=len(self.running), + num_waiting_reqs=len(self.waiting), + kv_cache_usage=self.kv_cache_manager.usage, + prefix_cache_stats=prefix_cache_stats, + connector_prefix_cache_stats=connector_prefix_cache_stats, + spec_decoding_stats=spec_decoding_stats, + kv_connector_stats=kv_connector_stats.data if kv_connector_stats else None, + ) def make_spec_decoding_stats( self, - spec_decoding_stats: Optional[SpecDecodingStats], + spec_decoding_stats: SpecDecodingStats | None, num_draft_tokens: int, num_accepted_tokens: int, - ) -> Optional[SpecDecodingStats]: + ) -> SpecDecodingStats | None: if not self.log_stats: return None if spec_decoding_stats is None: spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens) spec_decoding_stats.observe_draft( - num_draft_tokens=num_draft_tokens, - num_accepted_tokens=num_accepted_tokens) + num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens + ) return spec_decoding_stats def shutdown(self) -> None: @@ -1254,11 +1417,29 @@ def shutdown(self) -> None: # KV Connector Related Methods 
######################################################################## - def get_kv_connector(self) -> Optional[KVConnectorBase_V1]: + def _update_connector_prefix_cache_stats(self, request: Request) -> None: + if self.connector_prefix_cache_stats is None: + return + + self.connector_prefix_cache_stats.record( + num_tokens=request.num_tokens, + num_hits=request.num_external_computed_tokens, + preempted=request.num_preemptions > 0, + ) + + def _make_connector_prefix_cache_stats(self) -> PrefixCacheStats | None: + if self.connector_prefix_cache_stats is None: + return None + stats = self.connector_prefix_cache_stats + self.connector_prefix_cache_stats = PrefixCacheStats() + return stats + + def get_kv_connector(self) -> KVConnectorBase_V1 | None: return self.connector def _connector_finished( - self, request: Request) -> tuple[bool, Optional[dict[str, Any]]]: + self, request: Request + ) -> tuple[bool, dict[str, Any] | None]: """ Invoke the KV connector request_finished() method if applicable. @@ -1268,8 +1449,17 @@ def _connector_finished( if self.connector is None: return False, None - (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) - return self.connector.request_finished(request, block_ids) + block_ids = self.kv_cache_manager.get_block_ids(request.request_id) + + if not isinstance(self.connector, SupportsHMA): + # NOTE(Kuntai): We should deprecate this code path after we enforce + # all connectors to support HMA. + # Hybrid memory allocator should be already turned off for this + # code path, but let's double-check here. + assert len(self.kv_cache_config.kv_cache_groups) == 1 + return self.connector.request_finished(request, block_ids[0]) + + return self.connector.request_finished_all_groups(request, block_ids) def _update_waiting_for_remote_kv(self, request: Request) -> bool: """ @@ -1287,25 +1477,37 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool: if request.request_id not in self.finished_recving_kv_req_ids: return False - # Now that the blocks are ready, actually cache them. - (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) - num_computed_tokens = len(block_ids) * self.block_size - # Handle the case where num request tokens less than one block. - num_computed_tokens = min(num_computed_tokens, request.num_tokens) - if num_computed_tokens == request.num_tokens: - num_computed_tokens -= 1 - # This will cache the blocks iff caching is enabled. - self.kv_cache_manager.cache_blocks(request, num_computed_tokens) + if request.request_id in self.failed_recving_kv_req_ids: + # Request had KV load failures; num_computed_tokens was already + # updated in _update_requests_with_invalid_blocks + if request.num_computed_tokens: + # Cache any valid computed tokens. + self.kv_cache_manager.cache_blocks(request, request.num_computed_tokens) + else: + # No valid computed tokens, release allocated blocks. + # There may be a local cache hit on retry. + self.kv_cache_manager.free(request) - # Update the request state for scheduling. - request.num_computed_tokens = num_computed_tokens + self.failed_recving_kv_req_ids.remove(request.request_id) + else: + # Now that the blocks are ready, actually cache them. + (block_ids,) = self.kv_cache_manager.get_block_ids(request.request_id) + num_computed_tokens = len(block_ids) * self.block_size + # Handle the case where num request tokens less than one block. 
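The connector prefix-cache stats recorded above (per first-time-scheduled request: total tokens versus tokens served by the connector) reduce to a simple external hit rate when aggregated. A small sketch of that reduction (illustration only; not a vLLM API):

def connector_hit_rate(records: list[tuple[int, int]]) -> float:
    """records: (num_tokens, num_hits) pairs, one per recorded request."""
    total_tokens = sum(tokens for tokens, _ in records)
    total_hits = sum(hits for _, hits in records)
    return total_hits / total_tokens if total_tokens else 0.0

# Two requests: 2048 of 4096 tokens and 0 of 1024 tokens came from the connector.
assert connector_hit_rate([(4096, 2048), (1024, 0)]) == 0.4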
+ num_computed_tokens = min(num_computed_tokens, request.num_tokens) + if num_computed_tokens == request.num_tokens: + num_computed_tokens -= 1 + # This will cache the blocks iff caching is enabled. + self.kv_cache_manager.cache_blocks(request, num_computed_tokens) + + # Update the request state for scheduling. + request.num_computed_tokens = num_computed_tokens # Return that we are ready. self.finished_recving_kv_req_ids.remove(request.request_id) return True - def _update_from_kv_xfer_finished(self, - kv_connector_output: KVConnectorOutput): + def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput): """ KV Connector: update the scheduler state based on the output. @@ -1320,18 +1522,154 @@ def _update_from_kv_xfer_finished(self, self.connector.update_connector_output(kv_connector_output) # KV Connector:: update recv and send status from last step. - for req_id in (kv_connector_output.finished_recving or ()): + for req_id in kv_connector_output.finished_recving or (): logger.debug("Finished recving KV transfer for request %s", req_id) self.finished_recving_kv_req_ids.add(req_id) - for req_id in (kv_connector_output.finished_sending or ()): + for req_id in kv_connector_output.finished_sending or (): logger.debug("Finished sending KV transfer for request %s", req_id) - if req_id not in self.requests: - logger.warning( - "Got finished sending KV transfer for request %s," - "but the request is already freed.", req_id) + assert req_id in self.requests + self._free_blocks(self.requests[req_id]) + + def _update_requests_with_invalid_blocks( + self, requests: Iterable[Request], invalid_block_ids: set[int] + ) -> tuple[set[str], int]: + """ + Identify and update requests affected by invalid KV cache blocks. + + This method scans the given requests, detects those with invalid blocks + and adjusts their `num_computed_tokens` to the longest valid prefix. + For observability, it also accumulates the total number of tokens that + will need to be recomputed across all affected requests. + + Args: + requests: The set of requests to scan for invalid blocks. + invalid_block_ids: IDs of invalid blocks. + + Returns: + tuple: + - affected_req_ids (set[str]): IDs of requests impacted by + invalid blocks. + - total_affected_tokens (int): Total number of tokens that must + be recomputed across all affected requests (for observability). + """ + affected_req_ids: set[str] = set() + total_affected_tokens = 0 + # If a block is invalid and shared by multiple requests in the batch, + # these requests must be rescheduled, but only the first will recompute + # it. This set tracks blocks already marked for recomputation. + marked_invalid_block_ids: set[int] = set() + for request in requests: + is_affected = False + marked_invalid_block = False + req_id = request.request_id + # TODO (davidb): add support for hybrid memory allocator + (req_block_ids,) = self.kv_cache_manager.get_block_ids(req_id) + # We iterate only over blocks that may contain externally computed + # tokens + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + # Async loading. If num_computed_tokens is set it implies we + # already processed some block failures for it in a prior step + req_num_computed_tokens = ( + request.num_computed_tokens + if req_id in self.failed_recving_kv_req_ids + else len(req_block_ids) * self.block_size + ) else: - self._free_blocks(self.requests[req_id]) + # Sync loading. 
num_computed_tokens includes new tokens + req_num_computed_tokens = request.num_cached_tokens + + req_num_computed_blocks = ( + req_num_computed_tokens + self.block_size - 1 + ) // self.block_size + for idx, block_id in zip(range(req_num_computed_blocks), req_block_ids): + if block_id not in invalid_block_ids: + continue + + is_affected = True + + if block_id in marked_invalid_block_ids: + # This invalid block is shared with a previous request + # and was already marked for recomputation. + # This means this request can still consider this block + # as computed when rescheduled. + # Currently this only applies to sync loading; Async + # loading does not yet support block sharing + continue + + marked_invalid_block_ids.add(block_id) + + if marked_invalid_block: + # This request has already marked an invalid block for + # recomputation and updated its num_computed_tokens. + continue + + marked_invalid_block = True + # Truncate the computed tokens at the first failed block + request.num_computed_tokens = idx * self.block_size + num_affected_tokens = ( + req_num_computed_tokens - request.num_computed_tokens + ) + total_affected_tokens += num_affected_tokens + request.num_external_computed_tokens -= num_affected_tokens + + if is_affected: + if not marked_invalid_block: + # All invalid blocks of this request are shared with + # previous requests and will be recomputed by them. + # Revert to considering only cached tokens as computed. + # Currently this only applies to sync loading; Async + # loading does not yet support block sharing + total_affected_tokens += ( + request.num_computed_tokens - request.num_cached_tokens + ) + request.num_computed_tokens = request.num_cached_tokens + + affected_req_ids.add(request.request_id) + + return affected_req_ids, total_affected_tokens + + def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]: + total_requests_to_reschedule = 0 + total_tokens_to_reschedule = 0 + + # --- Handle async KV loads (WAITING_FOR_REMOTE_KVS) --- + async_load_reqs = ( + req + for req in self.waiting + if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS + ) + async_affected_req_ids, num_tokens_to_reschedule = ( + self._update_requests_with_invalid_blocks( + async_load_reqs, invalid_block_ids + ) + ) + total_requests_to_reschedule += len(async_affected_req_ids) + total_tokens_to_reschedule += num_tokens_to_reschedule + + # Mark requests with async KV load failures; they will be rescheduled + # once loading completes. + self.failed_recving_kv_req_ids |= async_affected_req_ids + + # --- Handle sync KV loads (running requests) --- + sync_affected_req_ids, num_tokens_to_reschedule = ( + self._update_requests_with_invalid_blocks(self.running, invalid_block_ids) + ) + + total_requests_to_reschedule += len(sync_affected_req_ids) + total_tokens_to_reschedule += num_tokens_to_reschedule + + if total_requests_to_reschedule: + logger.warning( + "Recovered from KV load failure: " + "%d request(s) rescheduled (%d tokens affected).", + total_requests_to_reschedule, + total_tokens_to_reschedule, + ) + + # Return the IDs of affected running requests to skip in + # update_from_output. + return sync_affected_req_ids @dataclass class RecomputeReqInfo: @@ -1339,53 +1677,7 @@ class RecomputeReqInfo: output_token_ids: ConstantList client_index: int = 0 - @dataclass -class RecomputeSchedulerOutput: - - # list of the requests that are scheduled for the first time. - # We cache the request's data in each worker process, so that we don't - # need to re-send it every scheduling step. 
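_update_requests_with_invalid_blocks above truncates a request's computed-token count at the first block that failed to load, so only the longest valid prefix is kept and everything after it is recomputed. A compact sketch of that truncation (illustration only; plain lists and sets stand in for the KV cache manager state):

def truncate_to_first_invalid_block(block_ids: list[int], num_computed_tokens: int,
                                    block_size: int, invalid: set[int]) -> int:
    """Return the adjusted num_computed_tokens: the longest still-valid prefix."""
    num_computed_blocks = (num_computed_tokens + block_size - 1) // block_size
    for idx, block_id in zip(range(num_computed_blocks), block_ids):
        if block_id in invalid:
            return idx * block_size       # this block onward must be recomputed
    return num_computed_tokens            # no failure within the computed prefix

# 48 computed tokens over blocks [10, 11, 12] with block_size=16; block 11 failed,
# so only the 16 tokens of block 10 remain usable.
assert truncate_to_first_invalid_block([10, 11, 12], 48, 16, invalid={11}) == 16
assert truncate_to_first_invalid_block([10, 11, 12], 48, 16, invalid=set()) == 48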
-    scheduled_new_reqs: list[NewRequestData]
-    # list of the requests that have been scheduled before.
-    # Since the request's data is already cached in the worker processes,
-    # we only send the diff to minimize the communication cost.
-    scheduled_cached_reqs: CachedRequestData
-
-    # req_id -> num_scheduled_tokens
-    # Number of tokens scheduled for each request.
-    num_scheduled_tokens: dict[str, int]
-    # Total number of tokens scheduled for all requests.
-    # Equal to sum(num_scheduled_tokens.values())
-    total_num_scheduled_tokens: int
-    # req_id -> spec_token_ids
-    # If a request does not have any spec decode tokens, it will not be
-    # included in the dictionary.
-    scheduled_spec_decode_tokens: dict[str, list[int]]
-    # req_id -> encoder input indices that need processing.
-    # E.g., if a request has [0, 1], it could mean the vision encoder needs
-    # to process that the request's 0-th and 1-th images in the current step.
-    scheduled_encoder_inputs: dict[str, list[int]]
-    # Number of common prefix blocks for all requests in each KV cache group.
-    # This can be used for cascade attention.
-    num_common_prefix_blocks: list[int]
-
-    # Request IDs that are finished in between the previous and the current
-    # steps. This is used to notify the workers about the finished requests
-    # so that they can free the cached states for those requests.
-    finished_req_ids: set[str]
-    # list of mm_hash strings associated with the encoder outputs to be
-    # freed from the encoder cache.
-    free_encoder_mm_hashes: list[str]
-
-    # Dict of request ids to their index within the batch
-    # for filling the next token bitmask
-    structured_output_request_ids: dict[str, int]
-    # the bitmask for the whole batch
-    grammar_bitmask: Optional[npt.NDArray[np.int32]]
-
+class RecomputeSchedulerOutput(SchedulerOutput):
     # requests that need to recompute kv
-    recomputed_reqs: list[RecomputeReqInfo]
-
-    # KV Cache Connector metadata.
-    kv_connector_metadata: Optional[KVConnectorMetadata] = None
+    recomputed_reqs: list[RecomputeReqInfo] | None = None
\ No newline at end of file
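End to end, a request evicted on the decode instance surfaces to the client as a finished output with stop_reason="recomputed" (see update_from_output above), carrying the RecomputeReqInfo data defined here; a front-end for disaggregated prefill/decode is expected to re-run its prefill elsewhere. A hypothetical consumer sketch (illustration only; the routing callables are assumptions, not vLLM APIs):

def route_engine_output(output: dict, resubmit_to_prefill, deliver_to_client) -> None:
    """`output` mirrors the EngineCoreOutput fields set for a recomputed request."""
    if output.get("stop_reason") == "recomputed":
        # The decode instance dropped this request's KV cache under memory
        # pressure; redo its prefill (prompt plus tokens generated so far) on a
        # prefill instance, then resume decoding from there.
        resubmit_to_prefill(output["request_id"])
    else:
        deliver_to_client(output)

route_engine_output(
    {"request_id": "req-7", "finish_reason": "stop", "stop_reason": "recomputed",
     "new_token_ids": [42]},
    resubmit_to_prefill=lambda req_id: print("re-prefill", req_id),
    deliver_to_client=lambda out: None,
)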