Generic cache snapshot/restore primitives + bit-exact PLD generator #8

Merged: benjamin-levin merged 1 commit into main from pld-exact-rollback on May 19, 2026

@benjamin-levin commented on May 18, 2026

Summary

This PR makes two additive contributions:

  1. Generic snap() / restore(snap) primitives on every built-in cache class, plus module-level snapshot_prompt_cache / restore_prompt_cache helpers. Default _BaseCache implementation is a no-op, so any existing custom cache subclass keeps working unchanged.
  2. prompt_lookup_generate_step: a Prompt Lookup Decoding (PLD) generator that uses the snap/restore primitives above to do bit-exact rollback on partial accept. Wired through stream_generate(prompt_lookup_num_tokens=N) and the mlx_lm.generate CLI (--prompt-lookup-num-tokens). Mutually exclusive with draft_model.

Motivation

The existing speculative_generate_step (draft-model path) gates rollback on can_trim_prompt_cache, which excludes non-trimmable recurrent caches like ArraysCache used by Gated Delta Net / Mamba-style models. That makes PLD's natural rollback-on-partial-accept unavailable for those models without dropping back to lossy approximations.

Snapshot + restore + re-forward of the accepted prefix is bit-exact for every built-in cache type, including ArraysCache. Cost is one extra forward of size n_accept + 1 on partial-accept cycles; full-accept cycles drop the snapshot for free (snapshots are by-reference, no tensor copy).
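The snapshot + restore + re-forward cycle described above can be sketched with a toy recurrent state. This is a minimal illustration, not the code in this PR: `ToyRecurrentCache`, `toy_forward`, `verify_cycle`, and `make_noisy_drafts` are hypothetical names, and the "model" is an arbitrary integer recurrence chosen only because its state cannot be trimmed.

```python
class ToyRecurrentCache:
    """Hypothetical stand-in for a non-trimmable recurrent cache."""

    def __init__(self):
        self.state = 0  # immutable int: each step rebinds, never mutates

    def snap(self):
        return self.state  # by-reference snapshot, no copy

    def restore(self, snap):
        self.state = snap


def toy_forward(cache, tokens):
    """Fold tokens into the state; return the greedy next token per position."""
    preds = []
    for t in tokens:
        cache.state = (cache.state * 31 + t) % 1009  # order-dependent update
        preds.append(cache.state % 7)
    return preds


def verify_cycle(cache, y, draft):
    """One PLD cycle: verify [y] + draft, roll back exactly on partial accept."""
    snap = cache.snap()
    preds = toy_forward(cache, [y] + draft)
    n_accept = 0
    while n_accept < len(draft) and draft[n_accept] == preds[n_accept]:
        n_accept += 1
    if n_accept < len(draft):
        # Partial accept: restore, then re-forward n_accept + 1 tokens so the
        # cache lands exactly "after y + accepted drafts".
        cache.restore(snap)
        preds = toy_forward(cache, [y] + draft[:n_accept])
    return draft[:n_accept] + [preds[n_accept]]


def ar_decode(first, n):
    cache, out, y = ToyRecurrentCache(), [], first
    while len(out) < n:
        y = toy_forward(cache, [y])[0]
        out.append(y)
    return out


def pld_decode(first, n, drafts):
    cache, out, y = ToyRecurrentCache(), [], first
    while len(out) < n:
        new = verify_cycle(cache, y, drafts(out))
        out.extend(new)
        y = new[-1]
    return out[:n]


def make_noisy_drafts(reference, k=3):
    """Draft from the true continuation, corrupting some drafts on purpose."""
    def drafts(out):
        d = list(reference[len(out):len(out) + k])
        if len(d) > 1 and len(out) % 2 == 0:
            d[1] = (d[1] + 1) % 7  # force a mismatch at the second draft token
        return d
    return drafts


reference = ar_decode(first=5, n=24)
assert pld_decode(5, 24, make_noisy_drafts(reference)) == reference
```

Because the re-forward replays exactly the committed tokens from the restored state, the speculative output equals the autoregressive one regardless of how good the drafts are; only the amount of work changes.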

Implementation

Cache snap/restore (mlx_lm/models/cache.py):

  • _BaseCache.snap() returns None and restore(None) is a no-op, so subclasses with no mutable state get correct behavior for free.
  • KVCache.snap() / QuantizedKVCache.snap() snapshot just offset -- writes happen at positions [offset, offset+N) in a pre-allocated buffer, so restoring the offset reverts the writes without copying.
  • RotatingKVCache.snap() snapshots (offset, _idx).
  • ChunkedKVCache.snap() snapshots (offset, start_position).
  • ConcatenateKVCache.snap() holds the pre-update keys / values references along with offset, because that cache replaces its arrays via mx.concatenate instead of in-place writes.
  • BatchKVCache.snap() / BatchRotatingKVCache.snap() snapshot the per-batch index and offset tuples.
  • ArraysCache.snap() returns (list(self.cache), self.lengths, self.left_padding). The model __call__ reassigns cache[i] by reference and advance() rebinds lengths / left_padding via -=, so a by-reference snapshot of the list and these two arrays is sufficient.
  • CacheList.snap() recursively snapshots its children.
  • Module-level snapshot_prompt_cache(cache_list) / restore_prompt_cache(cache_list, snap) iterate and dispatch.
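To make the per-class strategies above concrete, here is a small pure-Python sketch of three of them: the no-op base, the offset-only snapshot, and the by-reference list snapshot. The `Toy*` classes and the `toy_*` helpers are hypothetical stand-ins that use Python lists instead of `mx.array` buffers; they mirror the descriptions in the list and are not the classes in `mlx_lm/models/cache.py`.

```python
class ToyBaseCache:
    """Default: nothing to snapshot, nothing to restore."""

    def snap(self):
        return None

    def restore(self, snap):
        pass


class ToyKVCache(ToyBaseCache):
    """Writes land at [offset, offset + n) in a pre-allocated buffer."""

    def __init__(self, capacity=16):
        self.keys = [None] * capacity
        self.offset = 0

    def update(self, new_keys):
        n = len(new_keys)
        self.keys[self.offset:self.offset + n] = new_keys
        self.offset += n

    def visible(self):
        return self.keys[:self.offset]

    def snap(self):
        return self.offset  # just an int, no buffer copy

    def restore(self, snap):
        self.offset = snap  # later writes become invisible and get overwritten


class ToyArraysCache(ToyBaseCache):
    """Slots are rebound (cache[i] = new_state), never mutated in place."""

    def __init__(self, size=2):
        self.cache = [None] * size

    def snap(self):
        return list(self.cache)  # copies the references, not the states

    def restore(self, snap):
        self.cache = list(snap)


def toy_snapshot_prompt_cache(caches):
    return [c.snap() for c in caches]


def toy_restore_prompt_cache(caches, snaps):
    for c, s in zip(caches, snaps):
        c.restore(s)


kv, arr = ToyKVCache(), ToyArraysCache()
kv.update(["k0", "k1"])
arr.cache[0] = ("state-0",)

snaps = toy_snapshot_prompt_cache([kv, arr])
kv.update(["k2", "k3", "k4"])   # speculative writes
arr.cache[0] = ("state-1",)     # speculative rebind
toy_restore_prompt_cache([kv, arr], snaps)

assert kv.visible() == ["k0", "k1"]
assert arr.cache[0] == ("state-0",)
```

The stale entries beyond `offset` stay in the buffer after a restore; they are simply overwritten by the next update, which is why no copy is needed.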

PLD generator (mlx_lm/generate.py):

  • _pld_find_draft(generated, prompt, k_lookback, k_lookahead): right-to-left n-gram search returning the most recent prompt continuation after a match, or [] for AR fallback.
  • prompt_lookup_generate_step(...): per-cycle snapshot, single verify forward over [y, draft_1, ..., draft_k], accept the longest matching prefix, then either (a) full accept -> drop snapshot, or (b) partial accept -> restore_prompt_cache + re-forward [y, accepted...] so the cache lands at exactly "after y + accepted drafts".
  • Logits-processor history trimming on partial accept so processors see only the tokens that were actually committed.
  • Self-contained: no changes to speculative_generate_step.
  • Yields (int_token, mx.array_logprobs, bool_from_draft), matching the from_draft convention from the draft-model path.
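The n-gram lookup in the first bullet can be sketched as follows. This is an illustrative reading of the description, not the PR's `_pld_find_draft`: the name `find_draft_sketch` is hypothetical, and it assumes the n-gram is exactly the last `k_lookback` generated tokens, that only the prompt is searched, and that a match with no continuation is skipped.

```python
def find_draft_sketch(generated, prompt, k_lookback, k_lookahead):
    """Return up to k_lookahead prompt tokens following the most recent
    occurrence of the last k_lookback generated tokens, or [] if none."""
    if k_lookback <= 0 or k_lookahead <= 0 or len(generated) < k_lookback:
        return []  # insufficient history: caller falls back to plain AR

    ngram = list(generated[-k_lookback:])

    # Scan start positions right-to-left so the most recent match wins.
    # The upper bound leaves room for at least one continuation token.
    for start in range(len(prompt) - k_lookback - 1, -1, -1):
        if list(prompt[start:start + k_lookback]) == ngram:
            end = start + k_lookback
            return list(prompt[end:end + k_lookahead])
    return []


prompt = [1, 2, 3, 9, 1, 2, 3, 7, 8]

# [1, 2, 3] occurs twice; the later occurrence is followed by [7, 8].
assert find_draft_sketch([5, 1, 2, 3], prompt, 3, 2) == [7, 8]

# No occurrence of [4, 4, 4] in the prompt: empty draft, AR fallback.
assert find_draft_sketch([4, 4, 4], prompt, 3, 2) == []
```

An empty draft means the cycle degenerates to a single-token forward, which is the AR fallback mentioned above.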

Backwards compatibility

Strictly additive:

  • _BaseCache.snap() default returns None and restore(None) is a no-op. Existing custom cache subclasses keep working without modification.
  • stream_generate keeps prompt_lookup_num_tokens=None as default. Existing call sites are unaffected.
  • draft_model and prompt_lookup_num_tokens are validated to be mutually exclusive.
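A minimal sketch of the mutual-exclusion check, assuming it lives in a small helper. `select_decode_path` is a hypothetical name, and the positivity check on `prompt_lookup_num_tokens` is an assumption rather than something stated in this PR.

```python
def select_decode_path(draft_model=None, prompt_lookup_num_tokens=None):
    """Pick a decoding path, rejecting conflicting speculative options."""
    if draft_model is not None and prompt_lookup_num_tokens is not None:
        raise ValueError(
            "draft_model and prompt_lookup_num_tokens are mutually exclusive"
        )
    if prompt_lookup_num_tokens is not None:
        if prompt_lookup_num_tokens <= 0:  # assumed constraint
            raise ValueError("prompt_lookup_num_tokens must be positive")
        return "prompt_lookup"
    if draft_model is not None:
        return "draft_model"
    return "autoregressive"  # default: existing call sites are unaffected


assert select_decode_path() == "autoregressive"
assert select_decode_path(prompt_lookup_num_tokens=5) == "prompt_lookup"
```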

Benchmarks

All numbers from M4 Max, 4-bit MLX models, greedy sampling. Other speculative stacks (SnapKV, scheduler-fix, auto-spec router, persistent prompt cache) disabled via env vars so the PLD path is the only variable. prompt_lookup_num_tokens=5.

Cell 1 -- Trimmable KVCache (Qwen3-1.7B-4bit)

PLD on a standard KVCache model. This is the "baseline-possible" path: PLD already worked on trimmable caches before this PR via prompt-cache trimming.

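| workload  | ctx | prompt_tokens | AR tok/s | PLD tok/s | speedup |
|-----------|-----|---------------|----------|-----------|---------|
| echo      | 2k  | 2062          | 226.7    | 239.6     | 1.06x   |
| echo      | 8k  | 8134          | 157.6    | 298.8     | 1.90x   |
| code-edit | 2k  | 2118          | 225.7    | 186.0     | 0.82x   |
| code-edit | 8k  | 8278          | 159.0    | 121.2     | 0.76x   |
| qa-short  | 2k  | 2067          | 226.2    | 326.8     | 1.45x   |
| qa-short  | 8k  | 8139          | 157.7    | 243.5     | 1.54x   |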

echo and qa-short (high prompt-substring reuse in the continuation) win as expected; code-edit regresses because most drafted spans miss after a few tokens and the verify-forward overhead doesn't pay back. Both directions match the published PLD literature -- workload selection matters, which is what the auto-spec router in PR #7 is for.

Cell 2 -- Non-trimmable ArraysCache (Qwen3.6-35B-A3B-4bit, GDN MoE)

This is the headline cell. Qwen3.6 uses ArraysCache (gated delta net state), which fails can_trim_prompt_cache, so PLD was not available at all before this PR's snap/restore primitives.

Cache types confirmed at load: ['ArraysCache', 'KVCache'], has non-trimmable cache: True.

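| workload  | ctx | prompt_tokens | AR tok/s | PLD tok/s | speedup |
|-----------|-----|---------------|----------|-----------|---------|
| echo      | 2k  | 2018          | 111.0    | 186.0     | 1.68x   |
| echo      | 8k  | 7963          | 104.3    | 152.7     | 1.46x   |
| code-edit | 2k  | 2170          | 110.4    | 155.0     | 1.40x   |
| code-edit | 8k  | 8239          | 104.0    | 136.3     | 1.31x   |
| qa-short  | 2k  | 2018          | 127.3    | 102.8     | 0.81x   |
| qa-short  | 8k  | 7963          | 119.7    | 97.5      | 0.81x   |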

Both echo and code-edit win at every context. qa-short regresses because the answer is 7 tokens with no prompt-substring overlap, so the partial-accept re-forward overhead dominates. With PLD now reaching this cache class, the GDN/recurrent-cache model family gets speculative-decode wins it previously could not access.

Cell 3 -- Bit-exactness (PLD == AR token-for-token)

Greedy sampler, 5 prompts per model, 2 models (both cache classes covered):

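| model                              | cache classes        | prompts passed |
|------------------------------------|----------------------|----------------|
| mlx-community/Qwen3-1.7B-4bit      | KVCache              | 5 / 5          |
| mlx-community/Qwen3.6-35B-A3B-4bit | ArraysCache, KVCache | 5 / 5          |
| total                              |                      | 10 / 10        |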

Every PLD output sequence is identical to the AR output for the same prompt, length-for-length and token-for-token.

Cell 4 -- snap()/restore() overhead microbench

2000 iterations per measurement, times in microseconds.

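| cache class          | T_kv = 512 | T_kv = 4096 | T_kv = 32768 |
|----------------------|------------|-------------|--------------|
| KVCache              | 0.04       | 0.04        | 0.04         |
| QuantizedKVCache     | 0.05       | 0.05        | 0.04         |
| RotatingKVCache      | 0.06       | 0.06        | 0.06         |
| ChunkedKVCache       | 0.06       | 0.06        | 0.06         |
| ConcatenateKVCache   | 0.07       | 0.07        | 0.07         |
| ArraysCache          | 0.14       | 0.12        | 0.13         |
| BatchKVCache         | 0.06       | 0.06        | 0.07         |
| BatchRotatingKVCache | 0.07       | 0.07        | 0.06         |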

(Column = round-trip snap+restore microseconds.)

  • All built-in cache classes round-trip in well under a microsecond, independent of T_kv. This is consistent with the implementation: snapshots are by-reference (offset / index / cached array handles), not tensor copies.
  • ArraysCache is the slowest at ~0.13 μs because it snapshots three references (list(self.cache), lengths, left_padding); still 4-5 orders of magnitude below a single decode step.
  • Module-level helpers: snapshot_prompt_cache + restore_prompt_cache on a 4-layer KV stack -> 0.33 μs round-trip. Snap/restore overhead is negligible at the scale of PLD's per-cycle verify forward.
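A microbenchmark of this shape can be sketched in a few lines. This is not the harness used for the table above: `OffsetOnlyCache` and `roundtrip_us` are hypothetical, the cache is a pure-Python stand-in, and absolute timings from it will not match the reported numbers.

```python
import time


class OffsetOnlyCache:
    """Hypothetical cache whose snapshot is a single integer offset."""

    def __init__(self, t_kv):
        self.offset = t_kv

    def snap(self):
        return self.offset

    def restore(self, snap):
        self.offset = snap


def roundtrip_us(cache, iters=2000):
    """Mean snap + restore round-trip time in microseconds."""
    start = time.perf_counter()
    for _ in range(iters):
        cache.restore(cache.snap())
    elapsed = time.perf_counter() - start
    return elapsed / iters * 1e6


for t_kv in (512, 4096, 32768):
    cache = OffsetOnlyCache(t_kv)
    print(f"T_kv={t_kv:>6}: {roundtrip_us(cache):.3f} us per round-trip")
```

Since the snapshot is a scalar, the measured time should not depend on `t_kv`, which is the property the table is checking.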

Tests

tests/test_models.py -- TestCacheSnapRestore:

  • test_base_cache_snap_is_noop
  • test_kv_cache_snap_restore
  • test_quantized_kv_cache_snap_restore
  • test_rotating_kv_cache_snap_restore
  • test_chunked_kv_cache_snap_restore
  • test_concatenate_kv_cache_snap_restore
  • test_arrays_cache_snap_restore
  • test_cache_list_snap_restore
  • test_snapshot_restore_helpers

Each round-trip test creates a small synthetic cache (no model load), takes a snapshot, performs additional updates, restores, and asserts the offset + arrays match the snapshot point bit-exactly.

tests/test_generate.py -- TestPromptLookupDecoding:

  • test_pld_find_draft_returns_match
  • test_pld_find_draft_no_match
  • test_pld_find_draft_picks_most_recent
  • test_pld_find_draft_insufficient_history
  • test_prompt_lookup_generate_step_yield_shape
  • test_prompt_lookup_generate_step_matches_ar (greedy PLD == greedy AR, token-for-token)
  • test_prompt_lookup_generate_step_rejects_bad_args
  • test_stream_generate_prompt_lookup
  • test_stream_generate_prompt_lookup_conflicts_with_draft

The AR-equivalence test uses a repetitive prompt so PLD actually exercises the verify + rollback path rather than always falling back to AR.

Scope

  • Kept deliberately separate from speculative_generate_step to avoid invasive changes to that code path. A natural follow-up is to lift its can_trim_prompt_cache guard and route partial-accept through restore_prompt_cache, which would extend draft-model speculative decoding to non-trimmable caches the same way. Left for a separate PR.
  • Snapshots assume cache implementations only append or replace by reference during intervening forwards. That holds for every built-in cache. A custom cache that mutates in-place across the verify forward would need to override snap() / restore() to copy the affected slice.
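The second caveat can be illustrated with a toy cache that mutates its state in place. `InPlaceCache` is a hypothetical example, not a class from this repository: a by-reference snapshot aliases the live state and is silently corrupted, while an overriding `snap()` that copies survives the intervening update.

```python
class InPlaceCache:
    """Hypothetical custom cache that updates its state in place."""

    def __init__(self):
        self.state = [0.0, 0.0, 0.0, 0.0]

    def step(self, x):
        for i in range(len(self.state)):
            self.state[i] += x  # in-place mutation, no rebinding

    def snap_by_reference(self):
        return self.state  # aliases the live list: NOT safe for this cache

    def snap(self):
        return list(self.state)  # copy the affected data

    def restore(self, snap):
        self.state = list(snap)


cache = InPlaceCache()
aliased = cache.snap_by_reference()
copied = cache.snap()

cache.step(1.0)  # stands in for the verify forward

assert aliased == [1.0, 1.0, 1.0, 1.0]  # snapshot was mutated with the cache
assert copied == [0.0, 0.0, 0.0, 0.0]   # copy still holds the pre-forward state

cache.restore(copied)
assert cache.state == [0.0, 0.0, 0.0, 0.0]
```

The copy costs time proportional to the state size, so it only makes sense for caches that cannot be restructured to rebind instead of mutate.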

Relationship to PR #7

PR #7 (auto-speculative-router) carries its own PLD generator path and references the same snap/restore primitives. The overlap is acknowledged in that PR's body; this PR is the focused contribution of (a) the cache primitives and (b) a self-contained PLD step generator wired into stream_generate.

Adds a Prompt Lookup Decoding (PLD) generator that drafts from n-gram
matches in the prompt and verifies in a single forward, avoiding the
need for a separate draft model. Rollback on partial accept is exact
(snapshot + restore + re-forward of the accepted prefix), so PLD
output matches plain auto-regressive decoding under the same sampler.

The cache module gains generic snap() / restore() hooks plus
snapshot_prompt_cache / restore_prompt_cache helpers. Unlike
trim_prompt_cache, these work for non-trimmable recurrent caches
(ArraysCache used by Gated Delta Net / Mamba-style models), so PLD
is bit-exact across every built-in cache type.

Wiring:
- mlx_lm.generate gains --prompt-lookup-num-tokens.
- stream_generate accepts prompt_lookup_num_tokens=N (mutually
  exclusive with draft_model).

Tests:
- tests/test_models.py: snap/restore round-trip per cache class plus
  module-level helpers.
- tests/test_generate.py: n-gram lookup helper, step-generator yield
  shape, AR-equivalence under greedy sampler, arg validation, and
  stream_generate wiring.
@benjamin-levin changed the title from "generate: add prompt lookup decoding with bit-exact rollback" to "Generic cache snapshot/restore primitives + bit-exact PLD generator" on May 19, 2026
@benjamin-levin marked this pull request as ready for review on May 19, 2026 18:24
@benjamin-levin merged commit 42954bd into main on May 19, 2026
1 of 2 checks passed
benjamin-levin pushed a commit that referenced this pull request May 19, 2026