Add opt-in prompt-lookup decoding + auto-speculative router#1286
Closed
benjamin-levin wants to merge 5 commits into
Closed
Add opt-in prompt-lookup decoding + auto-speculative router#1286benjamin-levin wants to merge 5 commits into
benjamin-levin wants to merge 5 commits into
Conversation
Empty commit to trigger pull_request.yml workflow registration on the fork. No source changes. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds the fork to the if-gate for both jobs in pull_request.yml. Lets PR CI on the fork run against the self-hosted M4 Max runner registered on this fork. DROP THIS COMMIT BEFORE UPSTREAM SUBMISSION. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Public forks suppress push/pull_request workflow events by default; adding workflow_dispatch lets us manually trigger CI for the fork. DROP BEFORE UPSTREAM SUBMISSION. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds two new generators to mlx_lm.generate (both opt-in, default off):
* `prompt_lookup_generate_step` — PLD speculative decoding (drafts via
n-gram lookup against the prompt; no draft model required). Verifies
in one main-model forward, accepts the greedy prefix, trims the cache
on partial accept.
* `auto_speculative_generate_step` — routes between PLD and plain AR
based on prompt length, n-gram density, and a 16-token PLD probe.
Probe-failure path falls back to AR on the warm cache so the prefill
+ probe cost is paid once.
Wiring. `stream_generate` accepts two new kwargs:
* `auto_speculative=True` (default False) — route via the auto router.
* `prompt_lookup_num_tokens=N` — use PLD directly without the router.
Both are mutually exclusive with `draft_model=`. CLI exposes the
matching `--auto-speculative` and `--prompt-lookup-num-tokens` flags.
Motivation. On an Apple-silicon companion fork (mlx_fast) the same
router measured +17-25% across echo-heavy / code-edit / open-gen /
qa-short workloads vs plain AR, with bit-exact greedy output:
Qwen3.6-35B-A3B-4bit on M4 Max 36GB (N=1):
echo 1.17x
code-edit 1.08x
open-gen 1.25x
qa-short 1.21x
Cross-route correctness: 96/96 prompts bit-exact vs AR.
Trade-off. PLD has a per-cycle setup cost (the verify forward sees
1 + k_lookahead tokens instead of 1) — net wins require either a
successful draft or amortized prefill via the warm-cache fallback.
The router's length pre-filter + early-bail probe keeps the worst
case at "AR plus one probe" rather than "AR with PLD overhead on
every cycle."
Note. mlx-lm doesn't ship an MTP draft-head primitive, so the
companion fork's 3-way AR/MTP/PLD router collapses here into the
2-way AR/PLD router. The router shape stays identical so an MTP
arm can drop in later.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Author
|
Wrong target — recreating against benjamin-levin/mlx-lm for fork CI validation. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
WIP / fork validation PR. Pre-existing
mac_build_and_testfailures (transformers 5.x tokenizer compatibility) are unrelated to this change.Summary
Adds two new generators to
mlx_lm.generate(both opt-in, default off):prompt_lookup_generate_step— PLD speculative decoding: drafts the next k tokens by n-gram lookup against the prompt, verifies in one main-model forward, accepts the greedy prefix, trims the cache on partial accept. No draft model required.auto_speculative_generate_step— routes between PLD and plain AR based on prompt length, n-gram density, and a 16-token PLD probe. Probe-failure path falls back to AR on the warm cache so the prefill + probe cost is paid once.Motivation
On an Apple-silicon companion fork (
mlx_fast.auto_speculative) the same router measured +17-25% across echo-heavy / code-edit / open-gen / qa-short workloads vs plain AR, with bit-exact greedy output (96/96 prompts).The router pattern is the right default for "I don't know what the prompt looks like" workloads (e.g. an OpenAI-compatible server) because:
Measured impact (Qwen3.6-35B-A3B-4bit on M4 Max 36GB, N=1)
From companion-repo benchmarks (
mlx_fast/auto_speculative.py):Implementation
Files changed:
mlx_lm/generate.py(+559 / -7).New functions (all module-public):
prompt_lookup_generate_step(prompt, model, prompt_ids, *, prompt_lookup_num_tokens=8, prompt_lookup_max_matches=2, ...)— mirrorsspeculative_generate_step's shape but the draft comes from_pld_find_draft(longest-suffix n-gram match against the prompt) instead of a draft model.auto_speculative_generate_step(prompt, model, *, ...)— the router. Cheap length+n-gram pre-filter, then probes with PLD, gates on acceptance, continues with PLD or falls back to AR using the warm cache._pld_find_draft,_auto_spec_score— internal helpers.Opt-in flags (default off, matches the
prefer-prefill-scheduleropt-in pattern):stream_generate/generateaccept:auto_speculative=True— route via the auto router.prompt_lookup_num_tokens=N— use PLD directly without the router.Both are mutually exclusive with
draft_model=. CLI exposes the matching--auto-speculativeand--prompt-lookup-num-tokensflags.Cache handling: PLD requires a trimmable prompt cache; the router falls back to AR (with the same warm cache) when the cache type isn't trimmable.
Trade-off
PLD has a per-cycle setup cost (the verify forward sees
1 + k_lookaheadtokens instead of1). Net wins require either successful drafts or amortized prefill via the warm-cache fallback. The router's length pre-filter + early-bail probe keeps the worst case at "AR plus one probe" rather than "AR with PLD overhead on every cycle."Defaults are conservative:
_AUTO_SPEC_SHORT_LEN=256— below this, skip PLD entirely._AUTO_SPEC_PROBE_TOKENS=16— probe budget._AUTO_SPEC_PROBE_THRESHOLD=0.30— acceptance rate needed to commit to PLD._AUTO_SPEC_PROBE_EARLY_BAIL=4— consecutive misses → abort probe early.Test plan
_pld_find_draftreturns longest-suffix match, returns[]when no match or no continuation, handles empty inputs._auto_spec_scorereturns 0 for prompts below 256 tokens, ramps to 1 for long+repetitive prompts.setup_arg_parserparses--auto-speculativeand--prompt-lookup-num-tokens; defaults areFalseandNone.mlx_lm.generatemodule imports cleanly; all four new functions resolve.stream_generaterejectsauto_speculative=Truetogether withdraft_model=.mac_build_and_test(will fail on pre-existing transformers tokenizer tests, unrelated).check_lint.Companion repo
The reference implementation that motivated this PR — including 3-way AR/MTP/PLD routing, exact GDN rollback on partial accept, and the +17-25% measurements — lives in
mlx_fast/auto_speculative.py. mlx-lm doesn't ship an MTP draft-head primitive, so the companion's 3-way router collapses here into the 2-way AR/PLD router; the router shape is preserved so an MTP arm can drop in later.🤖 Generated with Claude Code