Generic cache snapshot/restore primitives + bit-exact PLD generator#8
Merged
Conversation
Adds a Prompt Lookup Decoding (PLD) generator that drafts from n-gram matches in the prompt and verifies in a single forward, avoiding the need for a separate draft model. Rollback on partial accept is exact (snapshot + restore + re-forward of the accepted prefix), so PLD output matches plain auto-regressive decoding under the same sampler. The cache module gains generic snap() / restore() hooks plus snapshot_prompt_cache / restore_prompt_cache helpers. Unlike trim_prompt_cache, these work for non-trimmable recurrent caches (ArraysCache used by Gated Delta Net / Mamba-style models), so PLD is bit-exact across every built-in cache type. Wiring: - mlx_lm.generate gains --prompt-lookup-num-tokens. - stream_generate accepts prompt_lookup_num_tokens=N (mutually exclusive with draft_model). Tests: - tests/test_models.py: snap/restore round-trip per cache class plus module-level helpers. - tests/test_generate.py: n-gram lookup helper, step-generator yield shape, AR-equivalence under greedy sampler, arg validation, and stream_generate wiring.
5b9a4bd to
3b2a129
Compare
benjamin-levin
pushed a commit
that referenced
this pull request
May 19, 2026
…uter stays in feature branch)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR makes two additive contributions:
snap()/restore(snap)primitives on every built-in cache class, plus module-levelsnapshot_prompt_cache/restore_prompt_cachehelpers. Default_BaseCacheimplementation is a no-op, so any existing custom cache subclass keeps working unchanged.prompt_lookup_generate_step: a Prompt Lookup Decoding (PLD) generator that uses the snap/restore primitives above to do bit-exact rollback on partial accept. Wired throughstream_generate(prompt_lookup_num_tokens=N)and themlx_lm.generateCLI (--prompt-lookup-num-tokens). Mutually exclusive withdraft_model.Motivation
The existing
speculative_generate_step(draft-model path) gates rollback oncan_trim_prompt_cache, which excludes non-trimmable recurrent caches likeArraysCacheused by Gated Delta Net / Mamba-style models. That makes PLD's natural rollback-on-partial-accept unavailable for those models without dropping back to lossy approximations.Snapshot + restore + re-forward of the accepted prefix is bit-exact for every built-in cache type, including
ArraysCache. Cost is one extra forward of sizen_accept + 1on partial-accept cycles; full-accept cycles drop the snapshot for free (snapshots are by-reference, no tensor copy).Implementation
Cache snap/restore (
mlx_lm/models/cache.py):_BaseCache.snap()returnsNoneandrestore(None)is a no-op, so subclasses with no mutable state get correct behavior for free.KVCache.snap()/QuantizedKVCache.snap()snapshot justoffset-- writes happen at positions[offset, offset+N)in a pre-allocated buffer, so restoring the offset reverts the writes without copying.RotatingKVCache.snap()snapshots(offset, _idx).ChunkedKVCache.snap()snapshots(offset, start_position).ConcatenateKVCache.snap()holds the pre-updatekeys/valuesreferences along withoffset, because that cache replaces its arrays viamx.concatenateinstead of in-place writes.BatchKVCache.snap()/BatchRotatingKVCache.snap()snapshot the per-batch index and offset tuples.ArraysCache.snap()returns(list(self.cache), self.lengths, self.left_padding). The model__call__reassignscache[i]by reference andadvance()rebindslengths/left_paddingvia-=, so a by-reference snapshot of the list and these two arrays is sufficient.CacheList.snap()recursively snapshots its children.snapshot_prompt_cache(cache_list)/restore_prompt_cache(cache_list, snap)iterate and dispatch.PLD generator (
mlx_lm/generate.py):_pld_find_draft(generated, prompt, k_lookback, k_lookahead): right-to-left n-gram search returning the most recent prompt continuation after a match, or[]for AR fallback.prompt_lookup_generate_step(...): per-cycle snapshot, single verify forward over[y, draft_1, ..., draft_k], accept the longest matching prefix, then either (a) full accept -> drop snapshot, or (b) partial accept ->restore_prompt_cache+ re-forward[y, accepted...]so the cache lands at exactly "after y + accepted drafts".speculative_generate_step.(int_token, mx.array_logprobs, bool_from_draft), matching thefrom_draftconvention from the draft-model path.Backwards compatibility
Strictly additive:
_BaseCache.snap()default returnsNoneandrestore(None)is a no-op. Existing custom cache subclasses keep working without modification.stream_generatekeepsprompt_lookup_num_tokens=Noneas default. Existing call sites are unaffected.draft_modelandprompt_lookup_num_tokensare validated to be mutually exclusive.Benchmarks
All numbers from M4 Max, 4-bit MLX models, greedy sampling. Other speculative stacks (SnapKV, scheduler-fix, auto-spec router, persistent prompt cache) disabled via env vars so the PLD path is the only variable.
prompt_lookup_num_tokens=5.Cell 1 -- Trimmable KVCache (Qwen3-1.7B-4bit)
PLD on a standard
KVCachemodel. This is the "baseline-possible" path: PLD already worked on trimmable caches before this PR via prompt-cache trimming.echoandqa-short(high prompt-substring reuse in the continuation) win as expected;code-editregresses because most drafted spans miss after a few tokens and the verify-forward overhead doesn't pay back. Both directions match the published PLD literature -- workload selection matters, which is what the auto-spec router in PR #7 is for.Cell 2 -- Non-trimmable
ArraysCache(Qwen3.6-35B-A3B-4bit, GDN MoE)This is the headline cell. Qwen3.6 uses
ArraysCache(gated delta net state), which failscan_trim_prompt_cache, so PLD was not available at all before this PR's snap/restore primitives.Cache types confirmed at load:
['ArraysCache', 'KVCache'], has non-trimmable cache:True.Both
echoandcode-editwin at every context.qa-shortregresses because the answer is 7 tokens with no prompt-substring overlap, so the partial-accept re-forward overhead dominates. With PLD now reaching this cache class, the GDN/recurrent-cache model family gets speculative-decode wins it previously could not access.Cell 3 -- Bit-exactness (PLD == AR token-for-token)
Greedy sampler, 5 prompts per model, 2 models (both cache classes covered):
Every PLD output sequence is identical to the AR output for the same prompt, length-for-length and token-for-token.
Cell 4 -- snap()/restore() overhead microbench
2000 iterations per measurement, times in microseconds.
(Column = round-trip snap+restore microseconds.)
T_kv. Implementation holds: snapshots are by-reference (offset / index / cached array handles), not tensor copies.ArraysCacheis the slowest at ~0.13 μs because it snapshots three references (list(self.cache),lengths,left_padding); still 4-5 orders of magnitude below a single decode step.snapshot_prompt_cache+restore_prompt_cacheon a 4-layer KV stack -> 0.33 μs round-trip. Snap/restore overhead is negligible at the scale of PLD's per-cycle verify forward.Tests
tests/test_models.py--TestCacheSnapRestore:test_base_cache_snap_is_nooptest_kv_cache_snap_restoretest_quantized_kv_cache_snap_restoretest_rotating_kv_cache_snap_restoretest_chunked_kv_cache_snap_restoretest_concatenate_kv_cache_snap_restoretest_arrays_cache_snap_restoretest_cache_list_snap_restoretest_snapshot_restore_helpersEach round-trip test creates a small synthetic cache (no model load), takes a snapshot, performs additional updates, restores, and asserts the offset + arrays match the snapshot point bit-exactly.
tests/test_generate.py--TestPromptLookupDecoding:test_pld_find_draft_returns_matchtest_pld_find_draft_no_matchtest_pld_find_draft_picks_most_recenttest_pld_find_draft_insufficient_historytest_prompt_lookup_generate_step_yield_shapetest_prompt_lookup_generate_step_matches_ar(greedy PLD == greedy AR, token-for-token)test_prompt_lookup_generate_step_rejects_bad_argstest_stream_generate_prompt_lookuptest_stream_generate_prompt_lookup_conflicts_with_draftThe AR-equivalence test uses a repetitive prompt so PLD actually exercises the verify + rollback path rather than always falling back to AR.
Scope
speculative_generate_stepto avoid invasive changes to that code path. A natural follow-up is to lift itscan_trim_prompt_cacheguard and route partial-accept throughrestore_prompt_cache, which would extend draft-model speculative decoding to non-trimmable caches the same way. Left for a separate PR.snap()/restore()to copy the affected slice.Relationship to PR #7
PR #7 (
auto-speculative-router) carries its own PLD generator path and references the same snap/restore primitives. The overlap is acknowledged in that PR's body; this PR is the focused contribution of (a) the cache primitives and (b) a self-contained PLD step generator wired intostream_generate.