Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 42 additions & 42 deletions omlx/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5092,50 +5092,50 @@ def _cleanup_finished(self, finished_ids: set[str]) -> None:
)
intermediate_snapshots = None

# Boundary merge + async_eval on the inference
# thread. async_eval dispatches KV array
# materialization without blocking, so the
# inference thread can start the next request
# immediately. The worker calls
# mx.synchronize() to wait for completion
# before extracting bytes.
with (
self._phase_timer("store_cache_main_prep"),
mx.stream(generation_stream),
):
boundary_override = self._get_boundary_store_override(
request_id,
cacheable_sequence,
)
if boundary_override is not None:
(
token_sequence_to_store,
boundary_cache,
boundary_model_config,
intermediate_snapshots,
) = boundary_override
cache_to_store = (
self._merge_boundary_with_full_cache(
boundary_cache, request._extracted_cache
)
)
if boundary_model_config is not None:
model_cache_config = boundary_model_config
logger.info(
f"Using boundary cache snapshot for {request_id}: "
f"storing {len(token_sequence_to_store)}/"
f"{len(full_token_sequence)} tokens "
f"(skipping trailing partial block, "
f"{len(intermediate_snapshots) if intermediate_snapshots else 0} "
f"intermediate snapshots)"
# Inference-thread store_cache prep, timed as
# three sub-phases (boundary / collect / dispatch)
# mirroring boundary_capture_* granularity.
# async_eval dispatches KV array materialization
# without blocking; the worker calls
# mx.synchronize() to wait before extracting
# bytes.
with mx.stream(generation_stream):
with self._phase_timer("store_cache_main_boundary"):
boundary_override = self._get_boundary_store_override(
request_id,
cacheable_sequence,
)
pre_eval_arrays = (
self._collect_arrays_from_extracted_cache(
cache_to_store
if boundary_override is not None:
(
token_sequence_to_store,
boundary_cache,
boundary_model_config,
intermediate_snapshots,
) = boundary_override
cache_to_store = (
self._merge_boundary_with_full_cache(
boundary_cache, request._extracted_cache
)
)
if boundary_model_config is not None:
model_cache_config = boundary_model_config
logger.info(
f"Using boundary cache snapshot for {request_id}: "
f"storing {len(token_sequence_to_store)}/"
f"{len(full_token_sequence)} tokens "
f"(skipping trailing partial block, "
f"{len(intermediate_snapshots) if intermediate_snapshots else 0} "
f"intermediate snapshots)"
)
with self._phase_timer("store_cache_main_collect"):
pre_eval_arrays = (
self._collect_arrays_from_extracted_cache(
cache_to_store
)
)
)
if pre_eval_arrays:
mx.async_eval(*pre_eval_arrays)
with self._phase_timer("store_cache_main_dispatch"):
if pre_eval_arrays:
mx.async_eval(*pre_eval_arrays)

if self._store_cache_executor is not None:
store_future = self._store_cache_executor.submit(
Expand Down