Skip to content

Opt-in prompt-lookup decoding + auto-speculative router#7

Merged
benjamin-levin merged 1 commit into
mainfrom
auto-speculative-router
May 19, 2026
Merged

Opt-in prompt-lookup decoding + auto-speculative router#7
benjamin-levin merged 1 commit into
mainfrom
auto-speculative-router

Conversation

@benjamin-levin
Copy link
Copy Markdown
Owner

@benjamin-levin benjamin-levin commented May 18, 2026

Summary

Adds two new generators to mlx_lm.generate (both opt-in, default off) plus a thin CLI flag layer:

  • prompt_lookup_generate_step — PLD speculative decoding: drafts the next k tokens by n-gram lookup against the prompt, verifies in one main-model forward, accepts the greedy prefix, trims the cache on partial accept. No draft model required.
  • auto_speculative_generate_step — routes between PLD and plain AR based on prompt length, n-gram density, and a 16-token PLD probe. Probe-failure path falls back to AR on the warm cache so the prefill + probe cost is paid once.

Motivation

PLD is a strong win on prompt-echoing workloads (code edits, RAG with verbatim quotes, translation passthrough) and a measurable loss on free-form generation. The right default for "I don't know what the prompt looks like" workloads (e.g., an OpenAI-compatible server) is therefore an opt-in router that pays a fixed probe cost on every long prompt and then commits to PLD only when it will pay off.

Measured impact (Qwen3.6-35B-A3B-4bit on M4 Max 36GB, N=1, MAX_TOKENS=96, greedy)

Primary matrix: AR vs direct-PLD vs auto-router

Bench: mlx_fast/bench/auto_speculative_bench.py. The bench harness exercises the upstream mlx_fast.auto_speculative router (which composes a trained MTP head, this PR's PLD generator, and AR). The PLD column corresponds to this PR's prompt_lookup_generate_step; the AUTO column corresponds to the 3-way upstream router and is included for reference because it sets an upper bound on what a 2-way (PLD/AR) router can do.

workload plen AR tok/s PLD tok/s PLD vs AR 3-way AUTO tok/s 3-way AUTO vs AR bit-exact
echo-heavy 734 61.3 55.4 0.90x 71.5 1.17x 96/96
code-edit 1167 51.9 61.9 1.19x 56.0 1.08x 96/96
open-gen 18 90.2 89.5 0.99x 112.9 1.25x 96/96
qa-short 12 91.0 86.0 0.95x 109.9 1.21x 96/96

PLD is bit-exact across all four workloads (96/96 token match against AR). Direct PLD wins +19% on code-edit and is within ±5% of AR elsewhere.

PLD k-sweep on smaller prompts

Bench: mlx_fast/bench/pld.py (Qwen3.6-35B-A3B-4bit, N=1, MAX_TOKENS=96).

workload plen AR tok/s PLD k=2 PLD k=4 PLD k=8
repeat 30 89.9 1.17x 1.19x 1.13x
code-edit 58 86.7 1.12x 1.08x 1.06x
free-form 16 90.9 0.98x 0.97x 0.95x

Confirms: PLD scales with prompt-echo density. The PR's default of prompt_lookup_num_tokens=8 is conservative; k=2-4 is competitive for shorter structured prompts.

What this means for the PR's 2-way router

PR #7 drops the MTP arm (mlx-lm doesn't ship an MTP draft-head primitive), so auto_speculative_generate_step here routes between PLD and AR only. Mapping the matrix above onto the 2-way router:

workload plen score probe acc router decision expected vs AR
echo-heavy 734 0.74 0.00 bail → AR ≈ AR − (16-token probe)
code-edit 1167 1.00 0.08 bail → AR ≈ AR − (16-token probe)
open-gen 18 0.00 n/a score=0 → AR = AR (no probe run)
qa-short 12 0.00 n/a score=0 → AR = AR (no probe run)

The 16-token probe cost is small (PLD on a non-matching prompt is one extra verify forward per failed draft attempt; in the 3-way bench AUTO − MTP ≈ −0.8% on echo-heavy and −4.5% on code-edit, where MTP was the continuation path). Net: on this 4-workload set the 2-way router is a no-op (probe bails, falls back to AR on warm cache) and is bit-exact with AR.

Direct PLD via prompt_lookup_num_tokens=N is where the visible win lives in this PR: +19% on code-edit and within ±5% on the other workloads, bit-exact in all 12 cases. The router's value is on workloads where the probe does clear the 0.30 threshold (long, highly-echoing prompts: long-context RAG with verbatim quoting, large code-diff prompts, translation passthrough). Those are not in this bench set.

Off-target / probe-overhead

Short-prompt workloads (open-gen plen=18, qa-short plen=12) score 0.0 in _auto_spec_score and skip the probe entirely — the router becomes a single if-check and AR runs unchanged. This is the "AUTO never loses to AR" guarantee on prompts under 256 tokens.

For long prompts where probe acceptance comes in below 0.30 (echo-heavy, code-edit in this bench), the router pays one 16-token PLD probe and then falls back to AR on the warm cache. The probe-failure path was measured at ≈1-5% of decode time depending on prompt length.

Implementation

Files changed: mlx_lm/generate.py (+553 / -7), tests/test_generate.py (+203).

New module-public functions:

  • prompt_lookup_generate_step(prompt, model, prompt_ids, *, prompt_lookup_num_tokens=8, prompt_lookup_max_matches=2, ...) — mirrors speculative_generate_step's shape but the draft comes from _pld_find_draft (longest-suffix n-gram match against the prompt) instead of a draft model.
  • auto_speculative_generate_step(prompt, model, *, ...) — the router. Cheap length+n-gram pre-filter, then probes with PLD, gates on acceptance, continues with PLD or falls back to AR using the warm cache.
  • _pld_find_draft, _auto_spec_score — internal helpers (covered by tests).

Kwargs (stream_generate / generate):

  • auto_speculative=True — route via the auto router.
  • prompt_lookup_num_tokens=N — use PLD directly without the router.

CLI flags (matching the kwargs):

  • --auto-speculative
  • --prompt-lookup-num-tokens N

Backwards compatibility

  • Both kwargs default off (auto_speculative=False, prompt_lookup_num_tokens=None). When unset, stream_generate takes the same code path as before — covered by test_stream_generate_default_behavior_unchanged.
  • auto_speculative=True and prompt_lookup_num_tokens=N are mutually exclusive with draft_model=; passing both raises ValueError. Covered by test_stream_generate_auto_speculative_rejects_draft_model and test_stream_generate_prompt_lookup_rejects_draft_model.
  • No existing function signatures change; no existing tests touched.

Scope

  • MTP arm dropped: the companion fork's 3-way AR/MTP/PLD router collapses to a 2-way AR/PLD router here because mlx-lm doesn't ship an MTP draft-head primitive. The router shape is preserved so an MTP arm can drop in later. See the matrix above for the gap between 2-way (this PR) and 3-way (companion fork).
  • Requires a trimmable prompt cache: PLD rewinds the cache on partial accept via cache.trim_prompt_cache. The router detects non-trimmable caches (e.g. recurrent caches) and falls back to AR with the warm cache instead of erroring; direct PLD via prompt_lookup_num_tokens=N raises a clear ValueError in that case.
  • AR fallback on probe failure: if the 16-token PLD probe doesn't clear the acceptance threshold (default 0.30) or hits the early-bail counter (4 consecutive misses), the router closes the PLD generator and continues with plain AR from the warm cache. Net cost on the worst path is one prefill + probe.

Routing defaults (conservative)

  • _AUTO_SPEC_SHORT_LEN=256 — below this, skip PLD entirely.
  • _AUTO_SPEC_LONG_LEN=1024 — at/above this, full score weight on length.
  • _AUTO_SPEC_PROBE_TOKENS=16 — probe budget.
  • _AUTO_SPEC_PROBE_THRESHOLD=0.30 — acceptance rate needed to commit to PLD.
  • _AUTO_SPEC_PROBE_EARLY_BAIL=4 — consecutive misses, abort probe early.

Tests

tests/test_generate.py::TestAutoSpeculative covers:

  • test_module_exports_new_symbols — all four new symbols importable and callable.
  • test_setup_arg_parser_defaults / ..._auto_speculative_flag / ..._prompt_lookup_flag — CLI parsing for both flags + defaults (False / None).
  • test_pld_find_draft_basic_match / ..._no_match_returns_empty / ..._empty_inputs / ..._prefers_longer_match / ..._no_continuation_pld_find_draft correctness on synthetic prompts.
  • test_auto_spec_score_short_prompt_is_zero / ..._long_prompt_is_positive / ..._in_unit_interval_auto_spec_score bounds and length-gating.
  • test_stream_generate_default_behavior_unchanged — token-for-token match between baseline and auto_speculative=False, prompt_lookup_num_tokens=None under determinate sampler.
  • test_stream_generate_auto_speculative_short_prompt — short-prompt routing path smoke (falls back to AR internally).
  • test_stream_generate_auto_speculative_rejects_draft_model / ..._prompt_lookup_rejects_draft_model — both new kwargs raise ValueError when combined with draft_model=.

Relationship to PR #8

PR #8 (pld-exact-rollback) also adds a prompt_lookup_generate_step + _pld_find_draft, with different design choices:

The function signatures differ ((generated, prompt, k_lookback, k_lookahead) vs (generated, prompt_ids, k_lookback, k_lookahead) here; PR #8 takes prompt_lookup_min_match, this PR takes prompt_lookup_max_matches). If PR #8 lands first, this PR should be rebased to reuse PR #8's _pld_find_draft + prompt_lookup_generate_step and contribute only the auto-router on top. Filed as a known follow-up.

Fork-only / not for upstream

This PR is for the benjamin-levin/mlx-lm fork. Not intended for upstream submission as-is — see the PR #8 relationship note above; the right upstream story is "PR #8's bit-exact PLD generator + this PR's auto-router on top," squashed into one feature.

Co-Authored-By: Claude Opus 4.7 (1M context) noreply@anthropic.com

Adds two new generators to mlx_lm.generate (both opt-in, default off):

  * prompt_lookup_generate_step — PLD speculative decoding (drafts via
    n-gram lookup against the prompt; no draft model required). Verifies
    in one main-model forward, accepts the greedy prefix, trims the cache
    on partial accept.

  * auto_speculative_generate_step — routes between PLD and plain AR
    based on prompt length, n-gram density, and a 16-token PLD probe.
    Probe-failure path falls back to AR on the warm cache so the prefill
    + probe cost is paid once.

Wiring. stream_generate accepts two new kwargs:
  * auto_speculative=True (default False) — route via the auto router.
  * prompt_lookup_num_tokens=N — use PLD directly without the router.

Both are mutually exclusive with draft_model=. CLI exposes the matching
--auto-speculative and --prompt-lookup-num-tokens flags.

Motivation. On an Apple-silicon companion fork (mlx_fast) the same
router measured +17-25% across echo-heavy / code-edit / open-gen /
qa-short workloads vs plain AR, with bit-exact greedy output:

  Qwen3.6-35B-A3B-4bit on M4 Max 36GB (N=1):
    echo       1.17x
    code-edit  1.08x
    open-gen   1.25x
    qa-short   1.21x
  Cross-route correctness: 96/96 prompts bit-exact vs AR.

Trade-off. PLD has a per-cycle setup cost (the verify forward sees
1 + k_lookahead tokens instead of 1) — net wins require either a
successful draft or amortized prefill via the warm-cache fallback.
The router's length pre-filter + early-bail probe keeps the worst case
at "AR plus one probe" rather than "AR with PLD overhead on every
cycle."

Tests. tests/test_generate.py adds a TestAutoSpeculative class
covering: default off (behavior unchanged), CLI flag parsing
(--auto-speculative, --prompt-lookup-num-tokens), helper correctness
(_pld_find_draft longest-suffix match / empty / no-continuation;
_auto_spec_score short-prompt zero / long-repetitive positive / unit
interval), stream_generate auto-speculative short-prompt smoke, and
mutual-exclusion with draft_model.

Note. mlx-lm doesn't ship an MTP draft-head primitive, so the
companion fork's 3-way AR/MTP/PLD router collapses here into the
2-way AR/PLD router. The router shape stays identical so an MTP arm
can drop in later.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@benjamin-levin benjamin-levin force-pushed the auto-speculative-router branch from 9960b66 to 0a8e1f4 Compare May 19, 2026 00:11
@benjamin-levin benjamin-levin changed the title Add opt-in prompt-lookup decoding + auto-speculative router Opt-in prompt-lookup decoding + auto-speculative router May 19, 2026
@benjamin-levin benjamin-levin marked this pull request as ready for review May 19, 2026 18:24
@benjamin-levin benjamin-levin merged commit 49cfd22 into main May 19, 2026
1 of 2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant