diff --git a/omlx/scheduler.py b/omlx/scheduler.py index d3f9ca1dd..fab20ed7b 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -5092,50 +5092,50 @@ def _cleanup_finished(self, finished_ids: set[str]) -> None: ) intermediate_snapshots = None - # Boundary merge + async_eval on the inference - # thread. async_eval dispatches KV array - # materialization without blocking, so the - # inference thread can start the next request - # immediately. The worker calls - # mx.synchronize() to wait for completion - # before extracting bytes. - with ( - self._phase_timer("store_cache_main_prep"), - mx.stream(generation_stream), - ): - boundary_override = self._get_boundary_store_override( - request_id, - cacheable_sequence, - ) - if boundary_override is not None: - ( - token_sequence_to_store, - boundary_cache, - boundary_model_config, - intermediate_snapshots, - ) = boundary_override - cache_to_store = ( - self._merge_boundary_with_full_cache( - boundary_cache, request._extracted_cache - ) - ) - if boundary_model_config is not None: - model_cache_config = boundary_model_config - logger.info( - f"Using boundary cache snapshot for {request_id}: " - f"storing {len(token_sequence_to_store)}/" - f"{len(full_token_sequence)} tokens " - f"(skipping trailing partial block, " - f"{len(intermediate_snapshots) if intermediate_snapshots else 0} " - f"intermediate snapshots)" + # Inference-thread store_cache prep, timed as + # three sub-phases (boundary / collect / dispatch) + # mirroring boundary_capture_* granularity. + # async_eval dispatches KV array materialization + # without blocking; the worker calls + # mx.synchronize() to wait before extracting + # bytes. + with mx.stream(generation_stream): + with self._phase_timer("store_cache_main_boundary"): + boundary_override = self._get_boundary_store_override( + request_id, + cacheable_sequence, ) - pre_eval_arrays = ( - self._collect_arrays_from_extracted_cache( - cache_to_store + if boundary_override is not None: + ( + token_sequence_to_store, + boundary_cache, + boundary_model_config, + intermediate_snapshots, + ) = boundary_override + cache_to_store = ( + self._merge_boundary_with_full_cache( + boundary_cache, request._extracted_cache + ) + ) + if boundary_model_config is not None: + model_cache_config = boundary_model_config + logger.info( + f"Using boundary cache snapshot for {request_id}: " + f"storing {len(token_sequence_to_store)}/" + f"{len(full_token_sequence)} tokens " + f"(skipping trailing partial block, " + f"{len(intermediate_snapshots) if intermediate_snapshots else 0} " + f"intermediate snapshots)" + ) + with self._phase_timer("store_cache_main_collect"): + pre_eval_arrays = ( + self._collect_arrays_from_extracted_cache( + cache_to_store + ) ) - ) - if pre_eval_arrays: - mx.async_eval(*pre_eval_arrays) + with self._phase_timer("store_cache_main_dispatch"): + if pre_eval_arrays: + mx.async_eval(*pre_eval_arrays) if self._store_cache_executor is not None: store_future = self._store_cache_executor.submit(