diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py
index 3573b2640..6aa70f196 100644
--- a/mlx_lm/generate.py
+++ b/mlx_lm/generate.py
@@ -1508,6 +1508,7 @@ def __init__(
         prefill_batch_size: int = 8,
         prefill_step_size: int = 2048,
         max_kv_size: Optional[int] = None,
+        prefer_prefill_when_pending: bool = False,
         stream=None,
     ):
         self.model = model
@@ -1519,6 +1520,7 @@ def __init__(
         self.prefill_batch_size = prefill_batch_size
         self.completion_batch_size = max(completion_batch_size, prefill_batch_size)
         self.max_kv_size = max_kv_size
+        self._prefer_prefill_when_pending = prefer_prefill_when_pending
 
         self._stream = stream or generation_stream
 
@@ -1770,8 +1772,36 @@ def _next(self):
         generation_responses = []
         prompt_responses = []
 
+        # With prefer_prefill_when_pending=True, skip the decode step when
+        # any prefill work is queued or in flight, unless the decode batch is
+        # already saturated. This matters most for long-context multi-agent
+        # workloads where the prefill chunk (~1s+ at 32k context) and a
+        # single decode step (~10ms) share the same GPU: interleaving them
+        # 1:1 — the default scheduler — drops already-decoding requests to
+        # near-zero tok/s while another request is prefilling. Pausing the
+        # decode lets the batch decode together at native speed once
+        # prefill catches up.
+        #
+        # `saturated` is a snapshot taken before the decode step and is used
+        # only to decide whether to decode. The early-exit check further down
+        # keeps re-reading the batch length after the decode step, so the
+        # default path (prefer_prefill_when_pending=False) is unchanged.
+        #
+        # NOTE(review): with the flag on, decode is skipped on every step
+        # that has pending prefill work and an unsaturated batch; confirm
+        # callers accept that latency trade-off.
+        pending_prefill = self._prefer_prefill_when_pending and (
+            len(self._unprocessed_sequences) > 0
+            or len(self._currently_processing) > 0
+            or len(self._prompt_batch) > 0
+        )
+        saturated = len(self._generation_batch) >= self.completion_batch_size
+        do_decode = len(self._generation_batch) > 0 and (
+            not pending_prefill or saturated
+        )
+
         # Generate tokens first
-        if len(self._generation_batch) > 0:
+        if do_decode:
             generation_responses = self._generation_batch.next()
             self._gen_tokens_counter += len(generation_responses)
             self._steps_counter += 1
diff --git a/tests/test_generate.py b/tests/test_generate.py
index 4f5bb4c91..93a28b88a 100644
--- a/tests/test_generate.py
+++ b/tests/test_generate.py
@@ -806,6 +806,72 @@ def test_batch_max_kv_size_none_creates_regular_cache(self):
             for cache in r.prompt_cache:
                 self.assertIsInstance(cache, KVCache)
 
+    def test_prefer_prefill_when_pending_default_false(self):
+        # Default behavior must be unchanged: the new flag defaults to False.
+        gen = BatchGenerator(self.model, max_tokens=1)
+        self.assertFalse(gen._prefer_prefill_when_pending)
+
+    def test_prefer_prefill_when_pending_accepted_and_stored(self):
+        # Opting in stores the flag without affecting other init kwargs.
+        gen = BatchGenerator(
+            self.model,
+            max_tokens=1,
+            prefer_prefill_when_pending=True,
+        )
+        self.assertTrue(gen._prefer_prefill_when_pending)
+
+    def test_prefer_prefill_pauses_decode_when_prefill_pending(self):
+        # With the flag on, a step that has both queued prefill work and an
+        # in-flight decode batch (below saturation) should skip the decode
+        # this cycle and drain prefill first. With the flag off (default),
+        # the decode runs as usual.
+        prompt_a = self.tokenizer.encode("Write a long story about a cat")
+        prompt_b = self.tokenizer.encode("Write a long story about a dog")
+
+        def run(prefer_prefill):
+            gen = BatchGenerator(
+                self.model,
+                max_tokens=5,
+                prefill_batch_size=1,
+                prefill_step_size=4,
+                completion_batch_size=4,
+                prefer_prefill_when_pending=prefer_prefill,
+            )
+            # Insert prompt A and drive next() until prefill has completed
+            # and the sequence has been promoted into _generation_batch.
+            # With prefill_step_size=4 and a longer prompt, the first sequence
+            # needs multiple next() cycles to traverse prefill before it can
+            # start decoding.
+            gen.insert([prompt_a])
+            for _ in range(20):
+                gen.next()
+                if len(gen._generation_batch) == 1:
+                    break
+            self.assertEqual(len(gen._generation_batch), 1)
+            self.assertLess(len(gen._generation_batch), gen.completion_batch_size)
+
+            # Queue a second prompt so prefill work is now pending while a
+            # decode batch is in flight and not yet saturated.
+            gen.insert([prompt_b])
+            self.assertGreater(
+                len(gen._unprocessed_sequences)
+                + len(gen._currently_processing)
+                + len(gen._prompt_batch),
+                0,
+            )
+
+            tokens_before = gen._gen_tokens_counter
+            gen.next()
+            return gen._gen_tokens_counter - tokens_before
+
+        decoded_with_flag_off = run(prefer_prefill=False)
+        decoded_with_flag_on = run(prefer_prefill=True)
+
+        # Flag off: the in-flight request gets a decode token this cycle.
+        self.assertGreater(decoded_with_flag_off, 0)
+        # Flag on: decode is paused so prefill can run first.
+        self.assertEqual(decoded_with_flag_on, 0)
+
 
 if __name__ == "__main__":
     unittest.main()