diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml
index 18633cbd7..6d142506c 100644
--- a/.github/workflows/pull_request.yml
+++ b/.github/workflows/pull_request.yml
@@ -4,6 +4,7 @@ on:
   push:
     branches: ["main"]
   pull_request:
+  workflow_dispatch:
 
 permissions:
   contents: read
@@ -14,7 +15,7 @@ concurrency:
 
 jobs:
   check_lint:
-    if: github.repository == 'ml-explore/mlx-lm'
+    if: github.repository == 'ml-explore/mlx-lm' || github.repository == 'benjamin-levin/mlx-lm'
     runs-on: ubuntu-22.04
     steps:
       - uses: actions/checkout@v5
@@ -24,7 +25,7 @@ jobs:
       - uses: pre-commit/action@v3.0.1
 
   mac_build_and_test:
-    if: github.repository == 'ml-explore/mlx-lm'
+    if: github.repository == 'ml-explore/mlx-lm' || github.repository == 'benjamin-levin/mlx-lm'
     runs-on: [self-hosted, macos]
     needs: check_lint
     steps:
diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py
index 3573b2640..f86381f95 100644
--- a/mlx_lm/generate.py
+++ b/mlx_lm/generate.py
@@ -219,6 +219,26 @@ def setup_arg_parser():
         help="Number of tokens to draft when using speculative decoding.",
         default=3,
     )
+    parser.add_argument(
+        "--auto-speculative",
+        action="store_true",
+        help=(
+            "Opt-in auto-router that picks between prompt-lookup (PLD) and "
+            "plain AR based on prompt length, n-gram density, and a short "
+            "PLD probe. No draft model required. Mutually exclusive with "
+            "--draft-model."
+        ),
+    )
+    parser.add_argument(
+        "--prompt-lookup-num-tokens",
+        type=int,
+        default=None,
+        help=(
+            "Use prompt-lookup speculative decoding with this many draft "
+            "tokens per cycle. With --auto-speculative it sets the router's "
+            "PLD draft length. Mutually exclusive with --draft-model."
+        ),
+    )
     return parser
 
 
@@ -654,6 +674,447 @@ def _draft_generate(y, num_draft):
         _rewind_cache(num_draft, n)
 
 
+# ---------------------------------------------------------------------------
+# Prompt-lookup decoding (PLD) and auto-speculative routing
+#
+# PLD drafts continuation tokens by searching the prompt for an n-gram suffix
+# of the tokens generated so far. When the model is reproducing prompt
+# content (RAG, code edit, translation, summarization) this gives 4-8
+# verified tokens per main-model forward without needing a draft model.
+#
+# Auto-speculative routing pairs PLD with a cheap probe: a length-based
+# pre-filter skips short prompts that PLD can't help (drafts there miss
+# constantly and waste a verify-forward per miss); on long prompts a
+# 16-token PLD probe measures actual acceptance, and the router commits to
+# PLD when acceptance clears a threshold, or falls back to plain AR
+# (reusing the warm cache, so the probe cost is paid once).
+# ---------------------------------------------------------------------------
+
+
+# Routing constants (conservative defaults; the PLD probe is cheap so we'd
+# rather over-test than under-test).
+_AUTO_SPEC_SHORT_LEN = 256  # below: skip PLD entirely
+_AUTO_SPEC_LONG_LEN = 1024  # at/above: PLD is a strong candidate
+_AUTO_SPEC_PROBE_TOKENS = 16  # tokens to generate in the probe
+_AUTO_SPEC_PROBE_THRESHOLD = 0.30  # min acceptance rate to commit to PLD
+_AUTO_SPEC_PROBE_EARLY_BAIL = 4  # consecutive PLD misses to abort probe
+
+
+def _pld_find_draft(
+    generated: List[int],
+    prompt_ids: Sequence[int],
+    k_lookback: int,
+    k_lookahead: int,
+) -> List[int]:
+    """Find the longest matching n-gram suffix of ``generated`` in
+    ``prompt_ids`` and return up to ``k_lookahead`` tokens that follow it.
+    Returns ``[]`` when no match followed by at least one token is found,
+    or when ``k_lookahead`` is not positive.
+
+    Searches for n-gram lengths from ``min(k_lookback, len(generated))``
+    down to 1, returning on the first hit (longer matches first → higher
+    acceptance odds). An occurrence at the very end of the prompt has
+    nothing after it to draft, so it is skipped in favour of earlier ones.
+    """
+    if not generated or not prompt_ids or k_lookahead <= 0:
+        return []
+    max_n = min(k_lookback, len(generated))
+    for n in range(max_n, 0, -1):
+        suffix = tuple(generated[-n:])
+        # Scan backwards: the most recent occurrence is the most likely to
+        # share its continuation with the current generation step. Start
+        # one position before the prompt's tail so that every hit has at
+        # least one follow-on token.
+        for i in range(len(prompt_ids) - n - 1, -1, -1):
+            if tuple(prompt_ids[i : i + n]) == suffix:
+                start = i + n
+                return list(prompt_ids[start : start + k_lookahead])
+    return []
+
+
+def _auto_spec_score(prompt_ids: Sequence[int]) -> float:
+    """Cheap heuristic in [0, 1] for whether PLD is worth probing.
+
+    Combines a length factor (PLD needs a long prompt to search) with a
+    bigram-repetition factor (lists/code/structured docs give better hits).
+    """
+    n = len(prompt_ids)
+    if n < _AUTO_SPEC_SHORT_LEN:
+        return 0.0
+    if n >= _AUTO_SPEC_LONG_LEN:
+        len_factor = 1.0
+    else:
+        len_factor = (n - _AUTO_SPEC_SHORT_LEN) / float(
+            _AUTO_SPEC_LONG_LEN - _AUTO_SPEC_SHORT_LEN
+        )
+
+    # Bigram density on the last 4k tokens — keeps scoring O(min(n, 4096))
+    # even for 100k-token prompts.
+    window = prompt_ids[-4096:]
+    if len(window) < 2:
+        return 0.0
+    bigrams = [(window[i], window[i + 1]) for i in range(len(window) - 1)]
+    unique = len(set(bigrams))
+    repeat_factor = 1.0 - (unique / len(bigrams))
+    # Map [0.05, 0.40] → [0, 1]; prose ≈ 0.05, structured ≈ 0.25.
+    repeat_factor = max(0.0, min(1.0, (repeat_factor - 0.05) / 0.35))
+    return max(0.0, min(1.0, 0.7 * len_factor + 0.3 * repeat_factor))
+
+
+def prompt_lookup_generate_step(
+    prompt: mx.array,
+    model: nn.Module,
+    prompt_ids: Sequence[int],
+    *,
+    prompt_lookup_num_tokens: int = 8,
+    prompt_lookup_max_matches: int = 2,
+    max_tokens: int = 256,
+    sampler: Optional[Callable[[mx.array], mx.array]] = None,
+    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
+    prompt_cache: Optional[Any] = None,
+    prefill_step_size: int = 512,
+    kv_bits: Optional[int] = None,
+    kv_group_size: int = 64,
+    quantized_kv_start: int = 0,
+) -> Generator[Tuple[mx.array, mx.array, bool], None, None]:
+    """
+    Prompt-lookup speculative decoding: draft tokens by n-gram lookup
+    against the prompt, verify in one main-model forward, accept the greedy
+    prefix.
+
+    No draft model required — drafts come from the prompt itself.
+
+    Args:
+        prompt (mx.array): The input prompt tokens.
+        model (nn.Module): The model to use for generation.
+        prompt_ids (Sequence[int]): The prompt as a Python list (used for
+          the n-gram search; avoids per-step ``.tolist()`` conversions).
+        prompt_lookup_num_tokens (int): Number of draft tokens to look
+          ahead per cycle (``k_lookahead``). Default: ``8``.
+        prompt_lookup_max_matches (int): Maximum n-gram length to search
+          for (``k_lookback``). Default: ``2``.
+        max_tokens, sampler, logits_processors, prompt_cache,
+        prefill_step_size, kv_bits, kv_group_size, quantized_kv_start:
+          Same semantics as :func:`speculative_generate_step`.
+
+    Yields:
+        Tuple[mx.array, mx.array, bool]: One token, a vector of log
+          probabilities, and a bool indicating if the token was contributed
+          by a prompt-lookup draft (analogous to ``from_draft`` from
+          :func:`speculative_generate_step`).
+
+    Raises:
+        ValueError: If ``prompt_cache`` cannot be trimmed (rejected draft
+          tokens have to be rewound out of the cache).
+    """
+    y = prompt.astype(mx.uint32)
+    prev_tokens = None
+
+    if prompt_cache is None:
+        prompt_cache = cache.make_prompt_cache(model)
+
+    if not cache.can_trim_prompt_cache(prompt_cache):
+        types = {type(c).__name__ for c in prompt_cache if not c.is_trimmable()}
+        raise ValueError(
+            "Prompt-lookup decoding requires a trimmable prompt cache "
+            f"(got {types})."
+        )
+
+    sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
+
+    quantize_cache_fn = functools.partial(
+        maybe_quantize_kv_cache,
+        quantized_kv_start=quantized_kv_start,
+        kv_group_size=kv_group_size,
+        kv_bits=kv_bits,
+    )
+
+    def _process_and_sample(tokens, logits):
+        if logits_processors:
+            for processor in logits_processors:
+                logits = processor(tokens, logits)
+        logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
+        return sampler(logprobs), logprobs
+
+    def _step(y_in, n_predict=1):
+        nonlocal prev_tokens
+        with mx.stream(generation_stream):
+            logits = model(y_in[None], cache=prompt_cache)
+            logits = logits[:, -n_predict:, :]
+            quantize_cache_fn(prompt_cache)
+            if logits_processors:
+                out_y, out_lp = [], []
+                if n_predict > 1:
+                    y_seen = y_in[: -(n_predict - 1)]
+                else:
+                    y_seen = y_in
+                for i in range(n_predict):
+                    prev_tokens = (
+                        mx.concatenate([prev_tokens, y_seen])
+                        if prev_tokens is not None
+                        else y_seen
+                    )
+                    y_seen, lp = _process_and_sample(prev_tokens, logits[:, i, :])
+                    out_y.append(y_seen)
+                    out_lp.append(lp)
+                return (
+                    mx.concatenate(out_y, axis=0),
+                    mx.concatenate(out_lp, axis=0),
+                )
+            return _process_and_sample(None, logits.squeeze(0))
+
+    def _prefill(y_in):
+        while y_in.size > 1:
+            n_to_process = min(prefill_step_size, y_in.size - 1)
+            model(y_in[:n_to_process][None], cache=prompt_cache)
+            quantize_cache_fn(prompt_cache)
+            mx.eval([c.state for c in prompt_cache])
+            y_in = y_in[n_to_process:]
+            mx.clear_cache()
+        return y_in
+
+    with mx.stream(generation_stream):
+        y = _prefill(y)
+
+    generated_for_lookup: List[int] = []
+    ntoks = 0
+    n = 0
+    num_draft = 0
+    try:
+        while ntoks < max_tokens:
+            # Find n-gram draft from the prompt. ``y`` is a single-token
+            # array carrying the last accepted token forward into the next
+            # verify call.
+            if generated_for_lookup:
+                remaining = max_tokens - ntoks - 1
+                draft_tokens_py = (
+                    _pld_find_draft(
+                        generated_for_lookup,
+                        prompt_ids,
+                        prompt_lookup_max_matches,
+                        min(prompt_lookup_num_tokens, max(remaining, 0)),
+                    )
+                    if remaining > 0
+                    else []
+                )
+            else:
+                draft_tokens_py = []
+
+            if draft_tokens_py:
+                draft_arr = mx.array(draft_tokens_py, mx.uint32)
+                num_draft = draft_arr.size
+                y_verify = mx.concatenate([y, draft_arr])
+            else:
+                num_draft = 0
+                y_verify = y
+
+            if prev_tokens is not None and num_draft > 0:
+                prev_tokens = prev_tokens[
+                    : prev_tokens.size - y.size - num_draft + 1
+                ]
+
+            tokens, logprobs = _step(y_verify, num_draft + 1)
+            mx.eval(tokens)
+            tokens_py = tokens.tolist()
+            if isinstance(tokens_py, int):
+                tokens_py = [tokens_py]
+
+            n = 0
+            while n < num_draft:
+                if tokens_py[n] != draft_tokens_py[n]:
+                    break
+                n += 1
+                ntoks += 1
+                generated_for_lookup.append(tokens_py[n - 1])
+                yield (
+                    mx.array(tokens_py[n - 1], mx.uint32),
+                    logprobs[n - 1],
+                    True,
+                )
+                if ntoks == max_tokens:
+                    break
+
+            if ntoks < max_tokens:
+                ntoks += 1
+                generated_for_lookup.append(tokens_py[n])
+                yield (
+                    mx.array(tokens_py[n], mx.uint32),
+                    logprobs[n],
+                    False,
+                )
+
+            if ntoks == max_tokens:
+                break
+
+            y = mx.array([tokens_py[n]], mx.uint32)
+            if prev_tokens is not None and num_draft > 0:
+                prev_tokens = prev_tokens[: -max(num_draft - n, 1)]
+            # Rewind cache by the number of unaccepted drafted tokens.
+            if num_draft > 0:
+                cache.trim_prompt_cache(prompt_cache, num_draft - n)
+    finally:
+        # Also runs when the consumer closes the generator at a ``yield``:
+        # trim the drafted tokens that were verified but not emitted.
+        if num_draft > 0:
+            cache.trim_prompt_cache(prompt_cache, num_draft - n)
+
+
+def auto_speculative_generate_step(
+    prompt: mx.array,
+    model: nn.Module,
+    *,
+    max_tokens: int = 256,
+    sampler: Optional[Callable[[mx.array], mx.array]] = None,
+    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
+    prompt_cache: Optional[Any] = None,
+    prefill_step_size: int = 512,
+    kv_bits: Optional[int] = None,
+    kv_group_size: int = 64,
+    quantized_kv_start: int = 0,
+    prompt_lookup_num_tokens: int = 8,
+    prompt_lookup_max_matches: int = 2,
+    auto_spec_probe_tokens: int = _AUTO_SPEC_PROBE_TOKENS,
+    auto_spec_probe_threshold: float = _AUTO_SPEC_PROBE_THRESHOLD,
+) -> Generator[Tuple[mx.array, mx.array, bool], None, None]:
+    """
+    Auto-router that picks between prompt-lookup (PLD) and plain AR based on
+    prompt characteristics and a short PLD probe.
+
+    Routing:
+      * Prompts shorter than ``_AUTO_SPEC_SHORT_LEN`` tokens, and prompt
+        caches that cannot be trimmed → plain AR from the start.
+      * Otherwise: run a PLD probe of up to ``auto_spec_probe_tokens``
+        tokens, yielding each token as soon as it is produced. If the
+        probe did not bail out early and its acceptance rate reaches
+        ``auto_spec_probe_threshold``, keep using the same PLD generator;
+        otherwise close it and continue with AR on the same (warm) cache.
+
+    Yields ``(token, logprobs, from_draft)`` triples; ``from_draft`` is
+    True when the token was contributed by a successful PLD draft.
+    """
+    prompt_ids: List[int] = (
+        prompt.tolist() if isinstance(prompt, mx.array) else list(prompt)
+    )
+    n_prompt = len(prompt_ids)
+
+    def _plain_ar(ar_prompt, ar_max_tokens):
+        # Plain autoregressive decoding on the shared cache, tagged with
+        # ``from_draft=False`` so every path yields the same triple shape.
+        for tok, lp in generate_step(
+            ar_prompt,
+            model,
+            max_tokens=ar_max_tokens,
+            sampler=sampler,
+            logits_processors=logits_processors,
+            prompt_cache=prompt_cache,
+            prefill_step_size=prefill_step_size,
+            kv_bits=kv_bits,
+            kv_group_size=kv_group_size,
+            quantized_kv_start=quantized_kv_start,
+        ):
+            yield tok, lp, False
+
+    # Layer 1: short prompts skip PLD entirely.
+    if n_prompt < _AUTO_SPEC_SHORT_LEN or _auto_spec_score(prompt_ids) <= 0.0:
+        yield from _plain_ar(prompt, max_tokens)
+        return
+
+    # Layer 2: probe via PLD on a shared cache.
+    if prompt_cache is None:
+        prompt_cache = cache.make_prompt_cache(model)
+
+    # PLD has to rewind rejected drafts, so a cache that cannot be trimmed
+    # goes straight to AR (still using this cache).
+    if not cache.can_trim_prompt_cache(prompt_cache):
+        yield from _plain_ar(prompt, max_tokens)
+        return
+
+    pld_gen = prompt_lookup_generate_step(
+        prompt,
+        model,
+        prompt_ids,
+        prompt_lookup_num_tokens=prompt_lookup_num_tokens,
+        prompt_lookup_max_matches=prompt_lookup_max_matches,
+        max_tokens=max_tokens,
+        sampler=sampler,
+        logits_processors=logits_processors,
+        prompt_cache=prompt_cache,
+        prefill_step_size=prefill_step_size,
+        kv_bits=kv_bits,
+        kv_group_size=kv_group_size,
+        quantized_kv_start=quantized_kv_start,
+    )
+
+    probe_budget = min(auto_spec_probe_tokens, max_tokens)
+    n_emitted = 0
+    n_drafted = 0
+    n_accepted = 0
+    consecutive_misses = 0
+    early_bail = False
+    last_tok = None
+    last_was_draft = False
+
+    try:
+        while n_emitted < probe_budget and not early_bail:
+            try:
+                tok, lp, from_draft = next(pld_gen)
+            except StopIteration:
+                break
+            if from_draft:
+                n_accepted += 1
+                n_drafted += 1
+                consecutive_misses = 0
+            elif not last_was_draft:
+                # A ``False`` right after a run of ``True`` is the verify
+                # token that closes an accepted draft, not a miss. Only a
+                # ``False`` that does not follow a ``True`` is a miss.
+                n_drafted += 1
+                consecutive_misses += 1
+                early_bail = consecutive_misses >= _AUTO_SPEC_PROBE_EARLY_BAIL
+            last_tok, last_was_draft = tok, from_draft
+            n_emitted += 1
+            # Hand each probe token over immediately so the consumer can
+            # stop iterating without the probe having run ahead of it.
+            yield tok, lp, from_draft
+
+        if n_emitted >= max_tokens:
+            return
+
+        accept_rate = n_accepted / max(1, n_drafted)
+        if not early_bail and accept_rate >= auto_spec_probe_threshold:
+            # Probe passed: keep draining the live PLD generator — its
+            # cache and lookup state are already warm.
+            yield from pld_gen
+            return
+    finally:
+        # Runs on every exit path, including the consumer closing this
+        # generator mid-probe. Closing the PLD generator triggers its own
+        # ``finally``, which trims unaccepted draft tokens from the cache.
+        pld_gen.close()
+
+    if n_emitted == 0:
+        # Nothing was probed (e.g. ``auto_spec_probe_tokens`` is 0): the PLD
+        # generator never started, so run AR over the full prompt.
+        yield from _plain_ar(prompt, max_tokens)
+        return
+
+    # Probe failed: continue with AR on the warm cache. ``generate_step``
+    # is given ``last_tok`` as its prompt, so the cache must end just
+    # before that token. That holds when the last probe token was not a
+    # draft; an accepted draft token, however, was part of the verify
+    # input and is still in the cache after the PLD generator's trim, so
+    # drop it to avoid feeding it twice.
+    if last_was_draft:
+        cache.trim_prompt_cache(prompt_cache, 1)
+
+    if isinstance(last_tok, mx.array):
+        last_tok_arr = last_tok if last_tok.ndim == 1 else last_tok[None]
+    else:
+        last_tok_arr = mx.array([last_tok], mx.uint32)
+
+    yield from _plain_ar(last_tok_arr, max_tokens - n_emitted)
+
+
 def stream_generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
@@ -698,19 +1159,68 @@ def stream_generate(
 
     kwargs["max_tokens"] = max_tokens
 
-    if draft_model is None:
+    auto_speculative = kwargs.pop("auto_speculative", False)
+    pld_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
+    pld_max_matches = kwargs.pop("prompt_lookup_max_matches", 2)
+    auto_probe_tokens = kwargs.pop(
+        "auto_spec_probe_tokens", _AUTO_SPEC_PROBE_TOKENS
+    )
+    auto_probe_threshold = kwargs.pop(
+        "auto_spec_probe_threshold", _AUTO_SPEC_PROBE_THRESHOLD
+    )
+
+    if draft_model is not None:
+        if auto_speculative or pld_num_tokens is not None:
+            raise ValueError(
+                "auto_speculative / prompt_lookup_num_tokens are mutually "
+                "exclusive with draft_model speculative decoding."
+            )
+        kwargs.pop("max_kv_size", None)
+        kwargs.pop("prompt_progress_callback", None)
+        token_generator = speculative_generate_step(
+            prompt, model, draft_model, **kwargs
+        )
+    elif auto_speculative:
+        # Opt-in auto router: length + n-gram score pre-filter, then a
+        # 16-token PLD probe, then commit to PLD or fall back to AR using
+        # the warm cache.
+        kwargs.pop("max_kv_size", None)
+        kwargs.pop("prompt_progress_callback", None)
+        kwargs.pop("num_draft_tokens", None)
+        token_generator = auto_speculative_generate_step(
+            prompt,
+            model,
+            prompt_lookup_num_tokens=(
+                pld_num_tokens if pld_num_tokens is not None else 8
+            ),
+            prompt_lookup_max_matches=pld_max_matches,
+            auto_spec_probe_tokens=auto_probe_tokens,
+            auto_spec_probe_threshold=auto_probe_threshold,
+            **kwargs,
+        )
+    elif pld_num_tokens is not None:
+        # Direct PLD path: no router, just n-gram speculative decoding.
+        kwargs.pop("max_kv_size", None)
+        kwargs.pop("prompt_progress_callback", None)
+        kwargs.pop("num_draft_tokens", None)
+        prompt_ids = (
+            prompt.tolist() if isinstance(prompt, mx.array) else list(prompt)
+        )
+        token_generator = prompt_lookup_generate_step(
+            prompt,
+            model,
+            prompt_ids,
+            prompt_lookup_num_tokens=pld_num_tokens,
+            prompt_lookup_max_matches=pld_max_matches,
+            **kwargs,
+        )
+    else:
         kwargs.pop("num_draft_tokens", None)
         token_generator = generate_step(prompt, model, **kwargs)
         # from_draft always false for non-speculative generation
         token_generator = (
             (token, logprobs, False) for token, logprobs in token_generator
         )
-    else:
-        kwargs.pop("max_kv_size", None)
-        kwargs.pop("prompt_progress_callback", None)
-        token_generator = speculative_generate_step(
-            prompt, model, draft_model, **kwargs
-        )
     with wired_limit(model, [generation_stream]):
         tic = time.perf_counter()
         for n, (token, logprobs, from_draft) in enumerate(token_generator):
@@ -2069,6 +2579,15 @@ def main():
         xtc_threshold=args.xtc_threshold,
         xtc_special_tokens=tokenizer.encode("\n") + list(tokenizer.eos_token_ids),
     )
+    extra_kwargs = {}
+    if args.auto_speculative:
+        if draft_model is not None:
+            raise ValueError(
+                "--auto-speculative is mutually exclusive with --draft-model."
+            )
+        extra_kwargs["auto_speculative"] = True
+    if args.prompt_lookup_num_tokens is not None:
+        extra_kwargs["prompt_lookup_num_tokens"] = args.prompt_lookup_num_tokens
     response = generate(
         model,
         tokenizer,
@@ -2083,6 +2602,7 @@ def main():
         quantized_kv_start=args.quantized_kv_start,
         draft_model=draft_model,
         num_draft_tokens=args.num_draft_tokens,
+        **extra_kwargs,
     )
     if not args.verbose:
         print(response)