Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 269 additions & 1 deletion mlx_lm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,14 @@ def setup_arg_parser():
help="Number of tokens to draft when using speculative decoding.",
default=3,
)
parser.add_argument(
"--prompt-lookup-num-tokens",
type=int,
help="If set, use Prompt Lookup Decoding (PLD) drafting this many "
"tokens per cycle from prompt n-grams. No draft model required. "
"Mutually exclusive with --draft-model.",
default=None,
)
return parser


Expand Down Expand Up @@ -654,12 +662,252 @@ def _draft_generate(y, num_draft):
_rewind_cache(num_draft, n)


def _pld_find_draft(
generated: List[int],
prompt: List[int],
k_lookback: int,
k_lookahead: int,
) -> List[int]:
"""
Find a PLD draft: match the last ``k_lookback`` generated tokens against
the prompt right-to-left and return the next ``k_lookahead`` prompt
tokens after the most recent match. Empty list if no match (caller
falls back to plain auto-regressive decoding).
"""
if len(generated) < k_lookback:
return []
suffix = generated[-k_lookback:]
for i in range(len(prompt) - k_lookback, -1, -1):
if prompt[i : i + k_lookback] == suffix:
return prompt[i + k_lookback : i + k_lookback + k_lookahead]
return []


def prompt_lookup_generate_step(
prompt: mx.array,
model: nn.Module,
*,
prompt_lookup_num_tokens: int = 5,
prompt_lookup_min_match: int = 3,
max_tokens: int = 256,
sampler: Optional[Callable[[mx.array], mx.array]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
prompt_cache: Optional[Any] = None,
prefill_step_size: int = 512,
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
) -> Generator[Tuple[mx.array, mx.array, bool], None, None]:
"""
Prompt Lookup Decoding (PLD): a training-free speculative-decoding
variant where the "draft" is an n-gram lookup in the prompt rather than
a separate draft model. At each step the last few generated tokens are
matched against the prompt; the next ``prompt_lookup_num_tokens`` tokens
in the prompt after the match are verified in a single forward pass.

Acceptance is highest when generation reproduces prompt content
(translation, code edit, summarisation, RAG answer); otherwise PLD
falls back to plain auto-regressive decoding for that step.

Rollback on partial accept uses snapshot + restore + re-forward of the
accepted prefix. This is bit-exact for both KV caches and non-trimmable
recurrent caches (e.g. ``ArraysCache`` used by Gated Delta Net /
Mamba-style models), so PLD output matches plain AR decoding under the
same sampler.

Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
prompt_lookup_num_tokens (int): Lookahead depth -- maximum number of
tokens drafted from the prompt per cycle. Default: ``5``.
prompt_lookup_min_match (int): N-gram length used to match the
suffix of the generated text against the prompt. Default: ``3``.
max_tokens (int): Maximum number of tokens to generate. Default:
``256``.
sampler (Callable, optional): Sampler from log probabilities.
Default: argmax (greedy).
logits_processors (List[Callable], optional): Optional logits
processors. Default: ``None``.
prompt_cache (List[Any], optional): A pre-computed prompt cache.
Updated in place if provided.
prefill_step_size (int): Step size for processing the prompt.
kv_bits, kv_group_size, quantized_kv_start: KV-cache quantisation
options (see :func:`generate_step`).

Yields:
Tuple[mx.array, mx.array, bool]: One token id, the vector of log
probabilities, and a bool that is ``True`` when the token was
drafted from the prompt and accepted (the analogue of
``from_draft`` in :func:`speculative_generate_step`).
"""
if prompt_lookup_num_tokens < 1:
raise ValueError("prompt_lookup_num_tokens must be >= 1")
if prompt_lookup_min_match < 1:
raise ValueError("prompt_lookup_min_match must be >= 1")

y = prompt.astype(mx.uint32)
prompt_ids = prompt.tolist() if hasattr(prompt, "tolist") else list(prompt)
generated: List[int] = []
prev_tokens: Optional[mx.array] = None

if prompt_cache is None:
model_cache = cache.make_prompt_cache(model)
else:
model_cache = prompt_cache

sampler = sampler or (lambda x: mx.argmax(x, axis=-1))

quantize_cache_fn = functools.partial(
maybe_quantize_kv_cache,
quantized_kv_start=quantized_kv_start,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)

def _process_and_sample(tokens, logits):
if logits_processors:
for processor in logits_processors:
logits = processor(tokens, logits)
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
return sampler(logprobs), logprobs

def _step(y_in, n_predict=1):
with mx.stream(generation_stream):
logits = model(y_in[None], cache=model_cache)
logits = logits[:, -n_predict:, :]
quantize_cache_fn(model_cache)
if logits_processors:
nonlocal prev_tokens
out_y, out_logprobs = [], []
y_seen = y_in[: -(n_predict - 1)] if n_predict > 1 else y_in
for i in range(n_predict):
prev_tokens = (
mx.concatenate([prev_tokens, y_seen])
if prev_tokens is not None
else y_seen
)
tok, lp = _process_and_sample(prev_tokens, logits[:, i, :])
out_y.append(tok)
out_logprobs.append(lp)
return mx.concatenate(out_y, axis=0), mx.concatenate(
out_logprobs, axis=0
)
else:
return _process_and_sample(None, logits.squeeze(0))

def _prefill(y_in):
while y_in.size > 1:
n_to_process = min(prefill_step_size, y_in.size - 1)
model(y_in[:n_to_process][None], cache=model_cache)
quantize_cache_fn(model_cache)
mx.eval([c.state for c in model_cache])
y_in = y_in[n_to_process:]
mx.clear_cache()
return y_in

with mx.stream(generation_stream):
y = _prefill(y)

ntoks = 0
try:
while ntoks < max_tokens:
# Look up a draft in the prompt using the recent generation tail.
# If we have not generated enough tokens yet (or the recent tail
# has no n-gram match in the prompt), fall back to plain AR.
draft = _pld_find_draft(
generated,
prompt_ids,
k_lookback=prompt_lookup_min_match,
k_lookahead=min(prompt_lookup_num_tokens, max_tokens - ntoks - 1),
)

if not draft:
# AR fallback: single-token forward.
tok, lp = _step(y, n_predict=1)
mx.eval(tok)
tok_id = tok.item()
generated.append(tok_id)
ntoks += 1
yield tok_id, lp, False
if ntoks == max_tokens:
break
y = tok
continue

# Verify [y, draft_1, ..., draft_k] in one forward. Snapshot the
# pre-verify state of every cache so we can roll back on partial
# accept; the snapshot is by-reference and is dropped on full
# accept.
k = len(draft)
pre_snap = cache.snapshot_prompt_cache(model_cache)
verify_in = mx.concatenate([y, mx.array(draft, mx.uint32)])
tokens, logprobs = _step(verify_in, n_predict=k + 1)
mx.eval(tokens)
tokens_list = tokens.tolist()

# Count accepted draft tokens (longest matching prefix).
n_accept = 0
for i in range(k):
if tokens_list[i] != draft[i]:
break
n_accept += 1

# Yield accepted drafts.
for i in range(n_accept):
if ntoks == max_tokens:
break
generated.append(tokens_list[i])
ntoks += 1
yield tokens_list[i], logprobs[i], True

if ntoks < max_tokens:
# Yield the bonus / correction token from the verify forward.
bonus = tokens_list[n_accept]
generated.append(bonus)
ntoks += 1
yield bonus, logprobs[n_accept], False

if ntoks == max_tokens:
break

if n_accept == k:
# Full accept: cache already at the correct post-verify
# state, drop the snapshot.
y = mx.array([tokens_list[k]], mx.uint32)
else:
# Partial accept: restore both caches to pre-verify state
# and re-forward the accepted prefix so the cache lands at
# exactly "after y + accepted drafts". This is bit-exact
# for both KV and recurrent caches and costs one extra
# forward of size (n_accept + 1) tokens.
cache.restore_prompt_cache(model_cache, pre_snap)
accepted = mx.array(draft[:n_accept], mx.uint32)
reforward = mx.concatenate([y, accepted])
with mx.stream(generation_stream):
model(reforward[None], cache=model_cache)
quantize_cache_fn(model_cache)
# Trim any logits-processor history that overshot the
# accepted range (the verify pulled (k+1) tokens through
# the processor; we only accepted (n_accept + 1)).
if prev_tokens is not None and logits_processors:
overshoot = k - n_accept
if overshoot > 0:
prev_tokens = prev_tokens[:-overshoot]
y = mx.array([tokens_list[n_accept]], mx.uint32)
finally:
# Nothing to roll back on early exit: every yielded token's cache
# state was either left as-is (full accept) or already restored +
# re-forwarded (partial accept) before the next iteration.
pass


def stream_generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: Union[str, mx.array, List[int]],
max_tokens: int = 256,
draft_model: Optional[nn.Module] = None,
prompt_lookup_num_tokens: Optional[int] = None,
**kwargs,
) -> Generator[GenerationResponse, None, None]:
"""
Expand All @@ -675,6 +923,10 @@ def stream_generate(
draft_model (Optional[nn.Module]): An optional draft model. If provided
then speculative decoding is used. The draft model must use the same
tokenizer as the main model. Default: ``None``.
prompt_lookup_num_tokens (Optional[int]): If set, use Prompt Lookup
Decoding (PLD) with this many drafted tokens per cycle. PLD drafts
from n-gram matches in the prompt -- no draft model is required.
Mutually exclusive with ``draft_model``. Default: ``None``.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.

Expand All @@ -698,13 +950,28 @@ def stream_generate(

kwargs["max_tokens"] = max_tokens

if draft_model is None:
if draft_model is not None and prompt_lookup_num_tokens is not None:
raise ValueError(
"draft_model and prompt_lookup_num_tokens are mutually exclusive."
)
if draft_model is None and prompt_lookup_num_tokens is None:
kwargs.pop("num_draft_tokens", None)
token_generator = generate_step(prompt, model, **kwargs)
# from_draft always false for non-speculative generation
token_generator = (
(token, logprobs, False) for token, logprobs in token_generator
)
elif prompt_lookup_num_tokens is not None:
kwargs.pop("max_kv_size", None)
kwargs.pop("prompt_progress_callback", None)
kwargs.pop("num_draft_tokens", None)
kwargs.pop("input_embeddings", None)
token_generator = prompt_lookup_generate_step(
prompt,
model,
prompt_lookup_num_tokens=prompt_lookup_num_tokens,
**kwargs,
)
else:
kwargs.pop("max_kv_size", None)
kwargs.pop("prompt_progress_callback", None)
Expand Down Expand Up @@ -2083,6 +2350,7 @@ def main():
quantized_kv_start=args.quantized_kv_start,
draft_model=draft_model,
num_draft_tokens=args.num_draft_tokens,
prompt_lookup_num_tokens=args.prompt_lookup_num_tokens,
)
if not args.verbose:
print(response)
Expand Down
Loading
Loading