Opt-in prompt-lookup decoding + auto-speculative router #7
Merged
Adds two new generators to mlx_lm.generate (both opt-in, default off):
* prompt_lookup_generate_step — PLD speculative decoding (drafts via
n-gram lookup against the prompt; no draft model required). Verifies
in one main-model forward, accepts the greedy prefix, trims the cache
on partial accept.
* auto_speculative_generate_step — routes between PLD and plain AR
based on prompt length, n-gram density, and a 16-token PLD probe.
Probe-failure path falls back to AR on the warm cache so the prefill
+ probe cost is paid once.
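For readers unfamiliar with PLD, the draft / verify / trim cycle described above can be sketched in plain Python. This is an illustration under assumptions, not the PR's code: find_draft and accept_greedy_prefix are hypothetical names, the argument order borrows the (generated, prompt_ids, k_lookback, k_lookahead) shape mentioned later in this description, and preferring the most recent match is a choice made here for the example.

```python
from typing import List, Sequence, Tuple


def find_draft(
    generated: Sequence[int],
    prompt_ids: Sequence[int],
    k_lookback: int,
    k_lookahead: int,
) -> List[int]:
    """Propose up to k_lookahead tokens by matching a suffix of `generated` in the prompt."""
    # Try the longest suffix first so longer matches win over shorter ones.
    for n in range(min(k_lookback, len(generated)), 0, -1):
        suffix = list(generated[-n:])
        # Scan right to left so the most recent occurrence is preferred (assumption).
        for start in range(len(prompt_ids) - n, -1, -1):
            if list(prompt_ids[start:start + n]) == suffix:
                continuation = list(prompt_ids[start + n:start + n + k_lookahead])
                if continuation:  # a match at the very end has nothing to draft
                    return continuation
    return []


def accept_greedy_prefix(
    draft: Sequence[int], verified: Sequence[int]
) -> Tuple[List[int], int]:
    """Return (tokens to emit, number of cache entries to trim).

    `verified` holds the main model's greedy choice at each of the
    len(draft) + 1 positions scored by the single verify forward.
    """
    n_accepted = 0
    for d, v in zip(draft, verified):
        if d != v:
            break
        n_accepted += 1
    # Accepted drafts plus the main model's own token at the first mismatch
    # (or the bonus token after a fully accepted draft).
    emitted = list(draft[:n_accepted]) + [verified[n_accepted]]
    # Rejected draft positions were written to the KV cache and must be removed.
    n_trim = len(draft) - n_accepted
    return emitted, n_trim
```

Because the emitted tokens are always the main model's own greedy choices, this acceptance rule is what keeps the output identical to plain AR under greedy decoding.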
Wiring. stream_generate accepts two new kwargs:
* auto_speculative=True (default False) — route via the auto router.
* prompt_lookup_num_tokens=N — use PLD directly without the router.
Both are mutually exclusive with draft_model=. CLI exposes the matching
--auto-speculative and --prompt-lookup-num-tokens flags.
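A minimal sketch of the mutual-exclusion check, assuming it runs before any generator is selected. check_speculative_kwargs is a hypothetical helper, not a function in mlx_lm, and the precedence of auto_speculative over prompt_lookup_num_tokens when both are set is an assumption made only for this example.

```python
from typing import Optional


def check_speculative_kwargs(
    draft_model: Optional[object] = None,
    auto_speculative: bool = False,
    prompt_lookup_num_tokens: Optional[int] = None,
) -> str:
    """Return the decoding mode implied by the kwargs, or raise on a conflict."""
    if draft_model is not None and auto_speculative:
        raise ValueError("auto_speculative cannot be combined with draft_model")
    if draft_model is not None and prompt_lookup_num_tokens is not None:
        raise ValueError("prompt_lookup_num_tokens cannot be combined with draft_model")
    if auto_speculative:
        return "auto"  # assumed to take precedence over direct PLD
    if prompt_lookup_num_tokens is not None:
        return "pld"
    # With neither new kwarg set, behavior is whatever it was before.
    return "draft" if draft_model is not None else "ar"
```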
Motivation. On an Apple-silicon companion fork (mlx_fast) the same
router measured +8% to +25% across echo-heavy / code-edit / open-gen /
qa-short workloads vs plain AR, with bit-exact greedy output:
Qwen3.6-35B-A3B-4bit on M4 Max 36GB (N=1):
echo 1.17x
code-edit 1.08x
open-gen 1.25x
qa-short 1.21x
Cross-route correctness: 96/96 prompts bit-exact vs AR.
Trade-off. PLD has a per-cycle setup cost (the verify forward sees
1 + k_lookahead tokens instead of 1) — net wins require either a
successful draft or amortized prefill via the warm-cache fallback.
The router's length pre-filter + early-bail probe keeps the worst case
at "AR plus one probe" rather than "AR with PLD overhead on every
cycle."
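The trade-off can be made concrete with a deliberately simplified per-cycle cost model. This is an assumption for illustration, not something measured in this PR: pld_cycle_speedup is a hypothetical helper, and verify_cost_ratio (the cost of a 1 + k_lookahead token forward relative to a 1-token forward) would have to be measured on the target hardware.

```python
def pld_cycle_speedup(accepted_per_cycle: float, verify_cost_ratio: float) -> float:
    """Tokens emitted per unit of AR-forward time for one PLD cycle.

    accepted_per_cycle: average number of draft tokens accepted per verify.
    verify_cost_ratio: cost of a (1 + k_lookahead)-token verify forward
        divided by the cost of a 1-token AR forward.
    """
    if verify_cost_ratio <= 0:
        raise ValueError("verify_cost_ratio must be positive")
    # Each cycle always emits one token from the main model, plus accepted drafts.
    return (1.0 + accepted_per_cycle) / verify_cost_ratio
```

Under this model PLD breaks even when accepted_per_cycle equals verify_cost_ratio - 1. A cycle with no accepted drafts is strictly slower than AR, which is why the router bails out early on prompts that do not echo.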
Tests. tests/test_generate.py adds a TestAutoSpeculative class
covering: default off (behavior unchanged), CLI flag parsing
(--auto-speculative, --prompt-lookup-num-tokens), helper correctness
(_pld_find_draft longest-suffix match / empty / no-continuation;
_auto_spec_score short-prompt zero / long-repetitive positive / unit
interval), stream_generate auto-speculative short-prompt smoke, and
mutual-exclusion with draft_model.
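The three _auto_spec_score properties checked by the tests (zero on short prompts, positive on long repetitive prompts, always in the unit interval) can be illustrated with a toy score. The formula below is an assumption: a length ramp between the 256 and 1024 thresholds from the routing defaults, multiplied by a repeated-trigram fraction. spec_score and ngram_density are hypothetical names, and the PR's actual formula may differ.

```python
from collections import Counter
from typing import Sequence

SHORT_LEN = 256   # below this the score is 0 (value from the routing defaults)
LONG_LEN = 1024   # at or above this the length term is 1 (same source)


def ngram_density(prompt_ids: Sequence[int], n: int = 3) -> float:
    """Fraction of n-gram positions whose n-gram occurs more than once."""
    grams = [tuple(prompt_ids[i:i + n]) for i in range(len(prompt_ids) - n + 1)]
    if not grams:
        return 0.0
    counts = Counter(grams)
    repeated = sum(1 for g in grams if counts[g] > 1)
    return repeated / len(grams)


def spec_score(prompt_ids: Sequence[int]) -> float:
    """Toy routing score in [0, 1]; higher means PLD is more likely to pay off."""
    plen = len(prompt_ids)
    if plen < SHORT_LEN:
        return 0.0  # short prompts never reach the probe
    length_term = min(1.0, (plen - SHORT_LEN) / (LONG_LEN - SHORT_LEN))
    return length_term * ngram_density(prompt_ids)
```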
Note. mlx-lm doesn't ship an MTP draft-head primitive, so the
companion fork's 3-way AR/MTP/PLD router collapses here into the
2-way AR/PLD router. The router shape stays identical so an MTP arm
can drop in later.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
Adds two new generators to mlx_lm.generate (both opt-in, default off) plus a thin CLI flag layer:
* prompt_lookup_generate_step: PLD speculative decoding. Drafts the next k tokens by n-gram lookup against the prompt, verifies in one main-model forward, accepts the greedy prefix, and trims the cache on partial accept. No draft model required.
* auto_speculative_generate_step: routes between PLD and plain AR based on prompt length, n-gram density, and a 16-token PLD probe. The probe-failure path falls back to AR on the warm cache, so the prefill + probe cost is paid once.
Motivation
PLD is a strong win on prompt-echoing workloads (code edits, RAG with verbatim quotes, translation passthrough) and a measurable loss on free-form generation. The right default for "I don't know what the prompt looks like" workloads (e.g., an OpenAI-compatible server) is therefore an opt-in router that pays a fixed probe cost on every long prompt and then commits to PLD only when it will pay off.
Measured impact (Qwen3.6-35B-A3B-4bit on M4 Max 36GB, N=1, MAX_TOKENS=96, greedy)
Primary matrix: AR vs direct-PLD vs auto-router
Bench: mlx_fast/bench/auto_speculative_bench.py. The bench harness exercises the upstream mlx_fast.auto_speculative router (which composes a trained MTP head, this PR's PLD generator, and AR). The PLD column corresponds to this PR's prompt_lookup_generate_step; the AUTO column corresponds to the 3-way upstream router and is included for reference because it sets an upper bound on what a 2-way (PLD/AR) router can do.
PLD is bit-exact across all four workloads (96/96 token match against AR). Direct PLD wins +19% on code-edit and is within ±5% of AR elsewhere.
PLD k-sweep on smaller prompts
Bench: mlx_fast/bench/pld.py (Qwen3.6-35B-A3B-4bit, N=1, MAX_TOKENS=96).
Confirms: PLD scales with prompt-echo density. The PR's default of prompt_lookup_num_tokens=8 is conservative; k=2-4 is competitive for shorter structured prompts.
What this means for the PR's 2-way router
PR #7 drops the MTP arm (mlx-lm doesn't ship an MTP draft-head primitive), so auto_speculative_generate_step here routes between PLD and AR only. Mapping the matrix above onto the 2-way router:
The 16-token probe cost is small (PLD on a non-matching prompt is one extra verify forward per failed draft attempt; in the 3-way bench AUTO − MTP ≈ −0.8% on echo-heavy and −4.5% on code-edit, where MTP was the continuation path). Net: on this 4-workload set the 2-way router is a no-op (probe bails, falls back to AR on warm cache) and is bit-exact with AR.
Direct PLD via prompt_lookup_num_tokens=N is where the visible win lives in this PR: +19% on code-edit and within ±5% on the other workloads, bit-exact in all 12 cases. The router's value is on workloads where the probe does clear the 0.30 threshold (long, highly-echoing prompts: long-context RAG with verbatim quoting, large code-diff prompts, translation passthrough). Those are not in this bench set.
Off-target / probe-overhead
Short-prompt workloads (open-gen plen=18, qa-short plen=12) score 0.0 in _auto_spec_score and skip the probe entirely: the router becomes a single if-check and AR runs unchanged. This is the "AUTO never loses to AR" guarantee on prompts under 256 tokens.
For long prompts where probe acceptance comes in below 0.30 (echo-heavy, code-edit in this bench), the router pays one 16-token PLD probe and then falls back to AR on the warm cache. The probe-failure path was measured at ≈1-5% of decode time depending on prompt length.
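The probe gate described above can be sketched as a small decision function over per-cycle draft outcomes. This is a sketch under assumptions, not the router's code: probe_route is a hypothetical name, the three constants reuse the values listed under the routing defaults, and defining the acceptance rate as accepted draft tokens over drafted tokens is a guess at how the 0.30 threshold is applied.

```python
from typing import Iterable, Tuple

PROBE_TOKENS = 16       # probe budget
PROBE_THRESHOLD = 0.30  # acceptance rate needed to commit to PLD
PROBE_EARLY_BAIL = 4    # consecutive misses that abort the probe


def probe_route(cycles: Iterable[Tuple[int, int]]) -> str:
    """Decide 'pld' or 'ar' from (drafted, accepted) counts per verify cycle."""
    drafted_total = accepted_total = emitted = misses = 0
    for drafted, accepted in cycles:
        drafted_total += drafted
        accepted_total += accepted
        emitted += accepted + 1  # accepted drafts plus the main model's token
        misses = 0 if accepted > 0 else misses + 1
        if misses >= PROBE_EARLY_BAIL:
            return "ar"  # early bail: the prompt is not echoing
        if emitted >= PROBE_TOKENS:
            break  # probe budget spent
    rate = accepted_total / drafted_total if drafted_total else 0.0
    return "pld" if rate >= PROBE_THRESHOLD else "ar"
```

Either outcome continues on the same warm cache, so a failed probe costs only the extra verify work done during the probe itself.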
Implementation
Files changed: mlx_lm/generate.py (+553 / -7), tests/test_generate.py (+203).
New module-public functions:
* prompt_lookup_generate_step(prompt, model, prompt_ids, *, prompt_lookup_num_tokens=8, prompt_lookup_max_matches=2, ...): mirrors speculative_generate_step's shape, but the draft comes from _pld_find_draft (longest-suffix n-gram match against the prompt) instead of a draft model.
* auto_speculative_generate_step(prompt, model, *, ...): the router. Cheap length + n-gram pre-filter, then probes with PLD, gates on acceptance, and continues with PLD or falls back to AR using the warm cache.
* _pld_find_draft, _auto_spec_score: internal helpers (covered by tests).
Kwargs (stream_generate / generate):
* auto_speculative=True: route via the auto router.
* prompt_lookup_num_tokens=N: use PLD directly without the router.
CLI flags (matching the kwargs):
* --auto-speculative
* --prompt-lookup-num-tokens N
Backwards compatibility
* Both kwargs default to off (auto_speculative=False, prompt_lookup_num_tokens=None). When unset, stream_generate takes the same code path as before; covered by test_stream_generate_default_behavior_unchanged.
* auto_speculative=True and prompt_lookup_num_tokens=N are mutually exclusive with draft_model=; passing either together with draft_model= raises ValueError. Covered by test_stream_generate_auto_speculative_rejects_draft_model and test_stream_generate_prompt_lookup_rejects_draft_model.
Scope
* PLD rollback uses cache.trim_prompt_cache. The router detects non-trimmable caches (e.g. recurrent caches) and falls back to AR with the warm cache instead of erroring; direct PLD via prompt_lookup_num_tokens=N raises a clear ValueError in that case.
Routing defaults (conservative)
* _AUTO_SPEC_SHORT_LEN=256: below this, skip PLD entirely.
* _AUTO_SPEC_LONG_LEN=1024: at or above this, full score weight on length.
* _AUTO_SPEC_PROBE_TOKENS=16: probe budget.
* _AUTO_SPEC_PROBE_THRESHOLD=0.30: acceptance rate needed to commit to PLD.
* _AUTO_SPEC_PROBE_EARLY_BAIL=4: consecutive misses after which the probe aborts early.
Tests
tests/test_generate.py::TestAutoSpeculative covers:
* test_module_exports_new_symbols: all four new symbols importable and callable.
* test_setup_arg_parser_defaults / ..._auto_speculative_flag / ..._prompt_lookup_flag: CLI parsing for both flags + defaults (False / None).
* test_pld_find_draft_basic_match / ..._no_match_returns_empty / ..._empty_inputs / ..._prefers_longer_match / ..._no_continuation: _pld_find_draft correctness on synthetic prompts.
* test_auto_spec_score_short_prompt_is_zero / ..._long_prompt_is_positive / ..._in_unit_interval: _auto_spec_score bounds and length-gating.
* test_stream_generate_default_behavior_unchanged: token-for-token match between baseline and auto_speculative=False, prompt_lookup_num_tokens=None under a deterministic sampler.
* test_stream_generate_auto_speculative_short_prompt: smoke test of the short-prompt routing path (falls back to AR internally).
* test_stream_generate_auto_speculative_rejects_draft_model / ..._prompt_lookup_rejects_draft_model: both new kwargs raise ValueError when combined with draft_model=.
Relationship to PR #8
PR #8 (pld-exact-rollback) also adds a prompt_lookup_generate_step + _pld_find_draft, with different design choices:
cache.trim_prompt_cache for PLD rollback (requires trimmable cache; falls back to AR otherwise).
The function signatures differ ((generated, prompt, k_lookback, k_lookahead) vs (generated, prompt_ids, k_lookback, k_lookahead) here; PR #8 takes prompt_lookup_min_match, this PR takes prompt_lookup_max_matches). If PR #8 lands first, this PR should be rebased to reuse PR #8's _pld_find_draft + prompt_lookup_generate_step and contribute only the auto-router on top. Filed as a known follow-up.
Fork-only / not for upstream
This PR is for the benjamin-levin/mlx-lm fork. It is not intended for upstream submission as-is (see the PR #8 relationship note above); the right upstream story is "PR #8's bit-exact PLD generator + this PR's auto-router on top," squashed into one feature.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>