Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions mlx_lm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1508,6 +1508,7 @@ def __init__(
prefill_batch_size: int = 8,
prefill_step_size: int = 2048,
max_kv_size: Optional[int] = None,
prefer_prefill_when_pending: bool = False,
stream=None,
):
self.model = model
Expand All @@ -1519,6 +1520,7 @@ def __init__(
self.prefill_batch_size = prefill_batch_size
self.completion_batch_size = max(completion_batch_size, prefill_batch_size)
self.max_kv_size = max_kv_size
self._prefer_prefill_when_pending = prefer_prefill_when_pending

self._stream = stream or generation_stream

Expand Down Expand Up @@ -1770,16 +1772,35 @@ def _next(self):
generation_responses = []
prompt_responses = []

# With prefer_prefill_when_pending=True, skip the decode step when
# any prefill work is queued or in flight, unless the decode batch is
# already saturated. This matters most for long-context multi-agent
# workloads where the prefill chunk (~1s+ at 32k context) and a
# single decode step (~10ms) share the same GPU: interleaving them
# 1:1 — the default scheduler — drops already-decoding requests to
# near-zero tok/s while another request is prefilling. Pausing the
# decode lets the batch decode together at native speed once
# prefill catches up.
pending_prefill = self._prefer_prefill_when_pending and (
len(self._unprocessed_sequences) > 0
or len(self._currently_processing) > 0
or len(self._prompt_batch) > 0
)
saturated = len(self._generation_batch) >= self.completion_batch_size
do_decode = len(self._generation_batch) > 0 and (
not pending_prefill or saturated
)

# Generate tokens first
if len(self._generation_batch) > 0:
if do_decode:
generation_responses = self._generation_batch.next()
self._gen_tokens_counter += len(generation_responses)
self._steps_counter += 1
if self._steps_counter % 512 == 0:
mx.clear_cache()

# Exit early because we already have our hands full with decoding
if len(self._generation_batch) >= self.completion_batch_size:
if saturated:
return prompt_responses, generation_responses

# Check if we have sequences and add them to the prompt batch
Expand Down
66 changes: 66 additions & 0 deletions tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,72 @@ def test_batch_max_kv_size_none_creates_regular_cache(self):
for cache in r.prompt_cache:
self.assertIsInstance(cache, KVCache)

def test_prefer_prefill_when_pending_default_false(self):
# Default behavior must be unchanged: the new flag defaults to False.
gen = BatchGenerator(self.model, max_tokens=1)
self.assertFalse(gen._prefer_prefill_when_pending)

def test_prefer_prefill_when_pending_accepted_and_stored(self):
# Opting in stores the flag without affecting other init kwargs.
gen = BatchGenerator(
self.model,
max_tokens=1,
prefer_prefill_when_pending=True,
)
self.assertTrue(gen._prefer_prefill_when_pending)

def test_prefer_prefill_pauses_decode_when_prefill_pending(self):
# With the flag on, a step that has both queued prefill work and an
# in-flight decode batch (below saturation) should skip the decode
# this cycle and drain prefill first. With the flag off (default),
# the decode runs as usual.
prompt_a = self.tokenizer.encode("Write a long story about a cat")
prompt_b = self.tokenizer.encode("Write a long story about a dog")

def run(prefer_prefill):
gen = BatchGenerator(
self.model,
max_tokens=5,
prefill_batch_size=1,
prefill_step_size=4,
completion_batch_size=4,
prefer_prefill_when_pending=prefer_prefill,
)
# Insert prompt A and drive next() until prefill has completed
# and the sequence has been promoted into _generation_batch.
# With prefill_step_size=4 and a longer prompt, the first sequence
# needs multiple next() cycles to traverse prefill before it can
# start decoding.
gen.insert([prompt_a])
for _ in range(20):
gen.next()
if len(gen._generation_batch) == 1:
break
self.assertEqual(len(gen._generation_batch), 1)
self.assertLess(len(gen._generation_batch), gen.completion_batch_size)

# Queue a second prompt so prefill work is now pending while a
# decode batch is in flight and not yet saturated.
gen.insert([prompt_b])
self.assertGreater(
len(gen._unprocessed_sequences)
+ len(gen._currently_processing)
+ len(gen._prompt_batch),
0,
)

tokens_before = gen._gen_tokens_counter
gen.next()
return gen._gen_tokens_counter - tokens_before

decoded_with_flag_off = run(prefer_prefill=False)
decoded_with_flag_on = run(prefer_prefill=True)

# Flag off: the in-flight request gets a decode token this cycle.
self.assertGreater(decoded_with_flag_off, 0)
# Flag on: decode is paused so prefill can run first.
self.assertEqual(decoded_with_flag_on, 0)


if __name__ == "__main__":
unittest.main()
Loading