From ad775f002793cf34e2f4227af4763d6e3ad757e0 Mon Sep 17 00:00:00 2001 From: benjamin-levin Date: Tue, 19 May 2026 00:59:06 -0400 Subject: [PATCH] BatchGenerator: optional prefer_prefill_when_pending scheduler Adds an opt-in `prefer_prefill_when_pending` kwarg to `BatchGenerator` that pauses decode steps while any prefill work is queued or in flight. Default is `False`, so existing behavior is unchanged. Motivation ---------- On Apple M-series unified-memory GPUs, prefill and decode share a single Metal command engine. With multiple concurrent requests at long context, the existing scheduler interleaves one decode step (~10 ms) with one prefill chunk (~1 s+ at 32k context) per cycle. Each in-flight decoding request is therefore stalled almost the entire time another request is prefilling, dropping per-victim decode to ~0.7-1 tok/s and yielding ~0x aggregate scaling vs. a single request. Pausing decode until prefill has drained lets the batch decode together at native batched speed once prefill catches up. Measured impact (Qwen3.6-35B-A3B-4bit, M4 Max 36GB) --------------------------------------------------- Three concurrent requests at 32k context each: mode aggregate per-request scaling ----------------------------------------------- ------------ -------------- ------- prefer_prefill_when_pending=False (default) 86 tok/s ~0.8 victim 0x prefer_prefill_when_pending=True 163 tok/s symmetric 1.89x Trade-off --------- With the flag enabled, late-arriving requests' TTFT is unchanged but already-decoding requests "freeze" briefly while new prefill runs. For background-agent and chatbot-batch workloads where aggregate throughput dominates, this is the right trade. For single-stream low-latency it is not. Hence: opt-in, default off. Backwards compatibility ----------------------- Default value is `False`, so existing callers see no behavior change. All other `BatchGenerator` semantics, attributes, and method signatures are untouched. Tests ----- Three tests in `tests/test_generate.py`: - test_prefer_prefill_when_pending_default_false: locks in default behavior. - test_prefer_prefill_when_pending_accepted_and_stored: asserts the kwarg is accepted and stored on the instance. - test_prefer_prefill_pauses_decode_when_prefill_pending: constructs the exact scheduler state the flag targets (one in-flight decode + one queued prefill, below saturation) and asserts the new branch fires: decode advances by zero tokens that cycle when the flag is on, and advances normally when the flag is off. --- mlx_lm/generate.py | 25 ++++++++++++++-- tests/test_generate.py | 66 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 2 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 3573b2640..6aa70f196 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -1508,6 +1508,7 @@ def __init__( prefill_batch_size: int = 8, prefill_step_size: int = 2048, max_kv_size: Optional[int] = None, + prefer_prefill_when_pending: bool = False, stream=None, ): self.model = model @@ -1519,6 +1520,7 @@ def __init__( self.prefill_batch_size = prefill_batch_size self.completion_batch_size = max(completion_batch_size, prefill_batch_size) self.max_kv_size = max_kv_size + self._prefer_prefill_when_pending = prefer_prefill_when_pending self._stream = stream or generation_stream @@ -1770,8 +1772,27 @@ def _next(self): generation_responses = [] prompt_responses = [] + # With prefer_prefill_when_pending=True, skip the decode step when + # any prefill work is queued or in flight, unless the decode batch is + # already saturated. This matters most for long-context multi-agent + # workloads where the prefill chunk (~1s+ at 32k context) and a + # single decode step (~10ms) share the same GPU: interleaving them + # 1:1 — the default scheduler — drops already-decoding requests to + # near-zero tok/s while another request is prefilling. Pausing the + # decode lets the batch decode together at native speed once + # prefill catches up. + pending_prefill = self._prefer_prefill_when_pending and ( + len(self._unprocessed_sequences) > 0 + or len(self._currently_processing) > 0 + or len(self._prompt_batch) > 0 + ) + saturated = len(self._generation_batch) >= self.completion_batch_size + do_decode = len(self._generation_batch) > 0 and ( + not pending_prefill or saturated + ) + # Generate tokens first - if len(self._generation_batch) > 0: + if do_decode: generation_responses = self._generation_batch.next() self._gen_tokens_counter += len(generation_responses) self._steps_counter += 1 @@ -1779,7 +1800,7 @@ def _next(self): mx.clear_cache() # Exit early because we already have our hands full with decoding - if len(self._generation_batch) >= self.completion_batch_size: + if saturated: return prompt_responses, generation_responses # Check if we have sequences and add them to the prompt batch diff --git a/tests/test_generate.py b/tests/test_generate.py index 4f5bb4c91..93a28b88a 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -806,6 +806,72 @@ def test_batch_max_kv_size_none_creates_regular_cache(self): for cache in r.prompt_cache: self.assertIsInstance(cache, KVCache) + def test_prefer_prefill_when_pending_default_false(self): + # Default behavior must be unchanged: the new flag defaults to False. + gen = BatchGenerator(self.model, max_tokens=1) + self.assertFalse(gen._prefer_prefill_when_pending) + + def test_prefer_prefill_when_pending_accepted_and_stored(self): + # Opting in stores the flag without affecting other init kwargs. + gen = BatchGenerator( + self.model, + max_tokens=1, + prefer_prefill_when_pending=True, + ) + self.assertTrue(gen._prefer_prefill_when_pending) + + def test_prefer_prefill_pauses_decode_when_prefill_pending(self): + # With the flag on, a step that has both queued prefill work and an + # in-flight decode batch (below saturation) should skip the decode + # this cycle and drain prefill first. With the flag off (default), + # the decode runs as usual. + prompt_a = self.tokenizer.encode("Write a long story about a cat") + prompt_b = self.tokenizer.encode("Write a long story about a dog") + + def run(prefer_prefill): + gen = BatchGenerator( + self.model, + max_tokens=5, + prefill_batch_size=1, + prefill_step_size=4, + completion_batch_size=4, + prefer_prefill_when_pending=prefer_prefill, + ) + # Insert prompt A and drive next() until prefill has completed + # and the sequence has been promoted into _generation_batch. + # With prefill_step_size=4 and a longer prompt, the first sequence + # needs multiple next() cycles to traverse prefill before it can + # start decoding. + gen.insert([prompt_a]) + for _ in range(20): + gen.next() + if len(gen._generation_batch) == 1: + break + self.assertEqual(len(gen._generation_batch), 1) + self.assertLess(len(gen._generation_batch), gen.completion_batch_size) + + # Queue a second prompt so prefill work is now pending while a + # decode batch is in flight and not yet saturated. + gen.insert([prompt_b]) + self.assertGreater( + len(gen._unprocessed_sequences) + + len(gen._currently_processing) + + len(gen._prompt_batch), + 0, + ) + + tokens_before = gen._gen_tokens_counter + gen.next() + return gen._gen_tokens_counter - tokens_before + + decoded_with_flag_off = run(prefer_prefill=False) + decoded_with_flag_on = run(prefer_prefill=True) + + # Flag off: the in-flight request gets a decode token this cycle. + self.assertGreater(decoded_with_flag_off, 0) + # Flag on: decode is paused so prefill can run first. + self.assertEqual(decoded_with_flag_on, 0) + if __name__ == "__main__": unittest.main()