diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 3573b2640..fb54053d5 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -3,10 +3,11 @@ import argparse import contextlib import copy -import functools import json +import math import sys import time +import warnings from collections import deque from dataclasses import dataclass from functools import partial @@ -38,7 +39,7 @@ TokenBuffer, load_prompt_cache, ) -from .sample_utils import make_sampler +from .sample_utils import categorical_sampling, make_sampler, make_sampler_chain from .tokenizer_utils import TokenizerWrapper from .utils import does_model_support_input_embeddings, load @@ -54,6 +55,7 @@ DEFAULT_SEED = None DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" DEFAULT_QUANTIZED_KV_START = 5000 +_CACHE_CLEAR_INTERVAL = 256 def str2bool(string): @@ -219,6 +221,12 @@ def setup_arg_parser(): help="Number of tokens to draft when using speculative decoding.", default=3, ) + parser.add_argument( + "--mtp", + action="store_true", + help="Use native Multi-Token Prediction for speculative decoding " + "(requires a model with an MTP head, e.g. 
Qwen3.5).", + ) return parser @@ -376,7 +384,7 @@ def generate_step( prompt_progress_callback = prompt_progress_callback or (lambda *_: None) - quantize_cache_fn = functools.partial( + quantize_cache_fn = partial( maybe_quantize_kv_cache, quantized_kv_start=quantized_kv_start, kv_group_size=kv_group_size, @@ -464,7 +472,7 @@ def _step(input_tokens: mx.array, input_embeddings: Optional[mx.array] = None): if n == max_tokens: break yield y.item(), logprobs - if n % 256 == 0: + if n % _CACHE_CLEAR_INTERVAL == 0: mx.clear_cache() y, logprobs = next_y, next_logprobs n += 1 @@ -534,7 +542,7 @@ def speculative_generate_step( sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) - quantize_cache_fn = functools.partial( + quantize_cache_fn = partial( maybe_quantize_kv_cache, quantized_kv_start=quantized_kv_start, kv_group_size=kv_group_size, @@ -654,12 +662,342 @@ def _draft_generate(y, num_draft): _rewind_cache(num_draft, n) +def mtp_generate_step( + prompt: mx.array, + model: nn.Module, + *, + max_tokens: int = 256, + logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, + prompt_cache: Optional[Any] = None, + prefill_step_size: int = 2048, + kv_bits: Optional[int] = None, + kv_group_size: int = 64, + quantized_kv_start: int = 0, + input_embeddings: Optional[mx.array] = None, + temp: float = 0.0, + top_p: float = 0.0, + top_k: int = 0, + min_p: float = 0.0, + min_tokens_to_keep: int = 1, + xtc_probability: float = 0.0, + xtc_threshold: float = 0.0, + xtc_special_tokens: List[int] = [], +) -> Generator[Tuple[mx.array, mx.array, bool], None, None]: + """A generator that uses the model's native MTP head for speculative decoding. + + Each iteration runs one backbone forward pass over the current token and its + pending draft, then one MTP forward pass to propose the next draft. Up to 2 + tokens are emitted per backbone step: one always-accepted backbone token and + one conditionally-accepted draft token. 
+ + The model must implement ``mtp_forward(hidden, next_tok, mtp_cache)`` and + support ``return_hidden=True`` in its ``__call__``. + + Yields: + Tuple[mx.array, mx.array, bool]: (token, log-probabilities, from_draft). + ``from_draft`` is ``True`` when the token came from the MTP head. + """ + y = prompt.astype(mx.uint32) + prev_tokens = None + + if prompt_cache is None: + model_cache = cache.make_prompt_cache(model) + mtp_cache = model.make_mtp_cache() + else: + # Split a pre-built cache at backbone length. If MTP entries are + # absent (e.g. cache created by make_prompt_cache), create them. + n_main = len(model.layers) + model_cache = prompt_cache[:n_main] + mtp_cache = prompt_cache[n_main:] or model.make_mtp_cache() + + _is_greedy = temp == 0 + + _filter_chain, _xtc_cell = ( + make_sampler_chain( + top_p, + top_k, + min_p, + min_tokens_to_keep, + xtc_probability, + xtc_threshold, + xtc_special_tokens, + ) + if not _is_greedy + else ([], None) + ) + + quantize_cache_fn = partial( + maybe_quantize_kv_cache, + quantized_kv_start=quantized_kv_start, + kv_group_size=kv_group_size, + kv_bits=kv_bits, + ) + + def _process_and_sample(tokens, logits, xtc_draw=None): + if logits_processors: + logits = logits[None] + for processor in logits_processors: + logits = processor(tokens, logits) + logits = logits.squeeze(0) + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + if _filter_chain: + if _xtc_cell is not None: + _xtc_cell[0] = xtc_draw # None = fresh draw; mx.array = shared draw + masked = logprobs + for f in _filter_chain: + masked = f(masked) + token = categorical_sampling(masked, temp) + # lp_accept must reflect the same filtered distribution as token. 
+ scaled = masked / temp + lp_accept = scaled - mx.logsumexp(scaled, axis=-1, keepdims=True) + elif _is_greedy: + token = mx.argmax(logprobs, axis=-1) + lp_accept = logprobs + else: + token = categorical_sampling(logprobs, temp) + scaled = logprobs / temp + lp_accept = scaled - mx.logsumexp(scaled, axis=-1, keepdims=True) + return token, logprobs, lp_accept + + def _clear_rollback(): + for c in model_cache: + if hasattr(c, "rollback_state"): + c.rollback_state = None + + def _rollback_draft(): + """Restore caches to the state after the confirmed token. + + SSM layers (ArraysCache): restore the conv/ssm snapshot saved by + GatedDeltaNet after the confirmed token. + Attention layers (KVCache): trim the draft-token entry. + """ + for c in model_cache: + if hasattr(c, "rollback_state") and c.rollback_state is not None: + conv_snap, ssm_snap = c.rollback_state + c[0] = conv_snap + c[1] = ssm_snap + c.rollback_state = None + elif c.is_trimmable(): + c.trim(1) + + def _step_backbone(y, prev_tokens, n_predict=1, n_confirmed=0, xtc_draw=None): + """Run the backbone on ``y`` and return (tokens, logprobs, accept_lps, hidden, prev_tokens).""" + with mx.stream(generation_stream): + logits, hidden = model( + y[None], cache=model_cache, return_hidden=True, n_confirmed=n_confirmed + ) + logits = logits[:, -n_predict:, :] + quantize_cache_fn(model_cache) + toks, lps, accept_lps = [], [], [] + for i in range(n_predict): + if logits_processors: + prev_tokens = ( + mx.concatenate([prev_tokens, y[i : i + 1]]) + if prev_tokens is not None + else y[i : i + 1] + ) + # Pass the shared XTC draw only for position 0 (the verify position). 
+ draw = xtc_draw if i == 0 else None + tok, lp, alp = _process_and_sample( + prev_tokens, logits[:, i, :].squeeze(0), draw + ) + toks.append(tok) + lps.append(lp) + accept_lps.append(alp) + return ( + mx.stack(toks), + mx.stack(lps), + mx.stack(accept_lps), + hidden, + prev_tokens, + ) + + def _step_mtp(hidden_last, main_tok, prev_tokens, *, cache_commit=None): + """Run the MTP head and return (draft_token, draft_logprobs, draft_accept_lp, xtc_draw). + + cache_commit: (hidden, tok) prepended as a cache-alignment position so that the + accepted draft token is committed to mtp_cache in the same batched forward. + """ + if cache_commit is not None: + align_h, align_tok = cache_commit + hidden_last = mx.concatenate([align_h, hidden_last], axis=1) + next_ids = mx.concatenate( + [align_tok.reshape(1, 1), main_tok.reshape(1, 1)], axis=1 + ) + else: + next_ids = main_tok.reshape(1, 1) + with mx.stream(generation_stream): + mtp_logits = model.mtp_forward(hidden_last, next_ids, mtp_cache) + quantize_cache_fn(mtp_cache) + mtp_logits = mtp_logits[:, -1, :].squeeze(0) + if logits_processors: + tokens_for_proc = ( + mx.concatenate([prev_tokens, main_tok.reshape(-1)]) + if prev_tokens is not None + else main_tok.reshape(-1) + ) + else: + tokens_for_proc = prev_tokens + # Draw the XTC boolean once here so the verify step can reuse it. + xtc_draw = mx.random.uniform() if _xtc_cell is not None else None + draft_tok, draft_lp, draft_accept_lp = _process_and_sample( + tokens_for_proc, mtp_logits, xtc_draw + ) + return draft_tok, draft_lp, draft_accept_lp, xtc_draw + + def _prefill(y, input_embeddings): + # Leave exactly 1 token for _step_backbone so the decode loop starts clean. 
+ total = len(input_embeddings) if input_embeddings is not None else y.size + while total > 1: + n = min(prefill_step_size, total - 1) + if input_embeddings is not None: + _, hidden = model( + y[:n][None], + cache=model_cache, + return_hidden=True, + input_embeddings=input_embeddings[:n][None], + ) + input_embeddings = input_embeddings[n:] + else: + _, hidden = model(y[:n][None], cache=model_cache, return_hidden=True) + model.mtp_forward(hidden, y[1 : n + 1][None], mtp_cache) + quantize_cache_fn(mtp_cache) + quantize_cache_fn(model_cache) + mx.eval([c.state for c in model_cache + mtp_cache if hasattr(c, "state")]) + y = y[n:] + total -= n + mx.clear_cache() + return y + + with mx.stream(generation_stream): + y = _prefill(y, input_embeddings) + + ntoks = 0 + last_cache_block = 0 + draft_tok = draft_lp = draft_accept_lp = draft_xtc_draw = None + + while ntoks < max_tokens: + if draft_tok is None: + # No pending draft: run backbone only, then generate first draft. + toks, lps, accept_lps, hidden, prev_tokens = _step_backbone( + y, prev_tokens, n_predict=1 + ) + mx.eval(toks) + main_tok, main_lp = toks[0], lps[0] + ntoks += 1 + yield main_tok.item(), main_lp, False + if ntoks >= max_tokens: + return + hidden_at_main = hidden[:, -1:, :] + draft_tok, draft_lp, draft_accept_lp, draft_xtc_draw = _step_mtp( + hidden_at_main, main_tok, prev_tokens + ) + mx.eval(draft_tok) + y = mx.array([main_tok.item()], mx.uint32) + else: + # Verify draft: run backbone over [y, draft_tok]. + # n_confirmed=1 causes GatedDeltaNet to snapshot its SSM/conv state + # after the confirmed token y, enabling exact rollback on rejection. 
+ y_with_draft = mx.concatenate([y, mx.array([draft_tok.item()], mx.uint32)]) + toks, lps, accept_lps, hidden, prev_tokens = _step_backbone( + y_with_draft, + prev_tokens, + n_predict=2, + n_confirmed=1, + xtc_draw=draft_xtc_draw, + ) + u = mx.random.uniform() + mx.eval(toks, draft_tok, u) + + verify_pred, bonus_tok = toks[0], toks[1] + verify_lp, bonus_lp = lps[0], lps[1] + verify_accept_lp = accept_lps[0] + draft_tok_id = draft_tok.item() + + if _is_greedy: + accept = verify_pred.item() == draft_tok_id + else: + # Probabilistic acceptance: min(1, p_target/p_draft) with temp-adjusted logprobs. + log_accept = ( + verify_accept_lp[draft_tok_id] - draft_accept_lp[draft_tok_id] + ).item() + accept = log_accept >= 0 or u.item() < math.exp(log_accept) + + hidden_at_confirmed = hidden[:, 0:1, :] + hidden_at_draft = hidden[:, 1:2, :] + + if accept: + _clear_rollback() + ntoks += 1 + yield draft_tok_id, draft_lp, True + if ntoks >= max_tokens: + return + ntoks += 1 + yield bonus_tok.item(), bonus_lp, False + if ntoks >= max_tokens: + return + # Next draft: one batched forward aligns the cache for the + # accepted draft token and generates the next draft together. + draft_tok, draft_lp, draft_accept_lp, draft_xtc_draw = _step_mtp( + hidden_at_draft, + bonus_tok, + prev_tokens, + cache_commit=(hidden_at_confirmed, draft_tok), + ) + mx.eval(draft_tok) + y = mx.array([bonus_tok.item()], mx.uint32) + else: + _rollback_draft() + if logits_processors and prev_tokens is not None: + prev_tokens = prev_tokens[:-1] # discard rejected draft token + verify_tok_id = verify_pred.item() + if not _is_greedy: + # Sample from residual distribution max(p_target - p_draft, 0) / Z + # (Leviathan et al. 2022 §2.3; Chen et al. 2023). Guarantees the + # output marginal equals the target distribution exactly. + # Both distributions are temperature-adjusted to match sampling. 
+ p_target = mx.exp(verify_accept_lp) + p_draft = mx.exp(draft_accept_lp) + residual = mx.maximum(p_target - p_draft, 0.0) + z = residual.sum(keepdims=True) + dist = mx.where(z > 0, residual, p_target) + # categorical treats -inf log-prob as p=0. + verify_tok_id = mx.random.categorical( + mx.log(dist).reshape(1, -1) + ).item() + ntoks += 1 + yield verify_tok_id, verify_lp, False + if ntoks >= max_tokens: + return + # Next draft from MTP at y's hidden state. + draft_tok, draft_lp, draft_accept_lp, draft_xtc_draw = _step_mtp( + hidden_at_confirmed, + mx.array([verify_tok_id], mx.uint32), + prev_tokens, + ) + mx.eval(draft_tok) + y = mx.array([verify_tok_id], mx.uint32) + block = ntoks // _CACHE_CLEAR_INTERVAL + if block > last_cache_block: + mx.clear_cache() + last_cache_block = block + + def stream_generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], prompt: Union[str, mx.array, List[int]], max_tokens: int = 256, draft_model: Optional[nn.Module] = None, + mtp: bool = False, + temp: float = 0.0, + top_p: float = 0.0, + top_k: int = 0, + min_p: float = 0.0, + min_tokens_to_keep: int = 1, + xtc_probability: float = 0.0, + xtc_threshold: float = 0.0, + xtc_special_tokens: List[int] = [], **kwargs, ) -> Generator[GenerationResponse, None, None]: """ @@ -675,6 +1013,8 @@ def stream_generate( draft_model (Optional[nn.Module]): An optional draft model. If provided then speculative decoding is used. The draft model must use the same tokenizer as the main model. Default: ``None``. + mtp (bool): Use native Multi-Token Prediction for speculative + decoding. Requires a model with an MTP head. Default: ``False``. kwargs: The remaining options get passed to :func:`generate_step`. See :func:`generate_step` for more details. 
@@ -698,19 +1038,43 @@ def stream_generate( kwargs["max_tokens"] = max_tokens - if draft_model is None: + if draft_model is not None: + kwargs.pop("max_kv_size", None) + kwargs.pop("prompt_progress_callback", None) + token_generator = speculative_generate_step( + prompt, model, draft_model, **kwargs + ) + elif mtp and hasattr(model, "mtp_forward"): + kwargs.pop("max_kv_size", None) + kwargs.pop("prompt_progress_callback", None) + kwargs.pop("num_draft_tokens", None) + kwargs.pop("sampler", None) # mtp_generate_step does not accept sampler= + token_generator = mtp_generate_step( + prompt, + model, + temp=temp, + top_p=top_p, + top_k=top_k, + min_p=min_p, + min_tokens_to_keep=min_tokens_to_keep, + xtc_probability=xtc_probability, + xtc_threshold=xtc_threshold, + xtc_special_tokens=xtc_special_tokens, + **kwargs, + ) + else: + if mtp: + warnings.warn( + "--mtp flag ignored: model does not have an MTP head. " + "Falling back to standard generation.", + stacklevel=2, + ) kwargs.pop("num_draft_tokens", None) token_generator = generate_step(prompt, model, **kwargs) # from_draft always false for non-speculative generation token_generator = ( (token, logprobs, False) for token, logprobs in token_generator ) - else: - kwargs.pop("max_kv_size", None) - kwargs.pop("prompt_progress_callback", None) - token_generator = speculative_generate_step( - prompt, model, draft_model, **kwargs - ) with wired_limit(model, [generation_stream]): tic = time.perf_counter() for n, (token, logprobs, from_draft) in enumerate(token_generator): @@ -2083,6 +2447,7 @@ def main(): quantized_kv_start=args.quantized_kv_start, draft_model=draft_model, num_draft_tokens=args.num_draft_tokens, + mtp=args.mtp, ) if not args.verbose: print(response) diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index b84c9d650..340079f2d 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -592,6 +592,10 @@ def nbytes(self): class ArraysCache(_BaseCache): + # Snapshot of (conv_state, ssm_state) 
saved after processing confirmed tokens + # in an MTP draft-verification step. Cleared after each step. + rollback_state: Optional[tuple] = None + def __new__(cls, *args, **kwargs): instance = super().__new__(cls) instance.left_padding = None diff --git a/mlx_lm/models/qwen3_5.py b/mlx_lm/models/qwen3_5.py index a86710e74..289989db8 100644 --- a/mlx_lm/models/qwen3_5.py +++ b/mlx_lm/models/qwen3_5.py @@ -50,6 +50,9 @@ class TextModelArgs(BaseModelArgs): moe_intermediate_size: int = 0 norm_topk_prob: bool = True + # MTP fields + mtp_num_hidden_layers: int = 0 + # Rope parameters rope_parameters: Optional[Dict[str, Union[float, str, bool, List[int]]]] = field( default_factory=lambda: { @@ -129,11 +132,59 @@ def __init__(self, config: TextModelArgs): self.sharding_group = None + def _process_chunk( + self, + qkv_chunk: mx.array, + a_chunk: mx.array, + b_chunk: mx.array, + conv_state: mx.array, + ssm_state: Optional[mx.array], + ssm_mask: Optional[mx.array] = None, + lengths: Optional[mx.array] = None, + ): + B, S_chunk = qkv_chunk.shape[:2] + conv_in = mx.concatenate([conv_state, qkv_chunk], axis=1) + n_keep = self.conv_kernel_size - 1 + if lengths is not None: + ends = mx.clip(lengths, 0, S_chunk) + positions = (ends[:, None] + mx.arange(n_keep))[..., None] + new_conv_state = mx.take_along_axis(conv_in, positions, axis=1) + else: + new_conv_state = mx.contiguous(conv_in[:, -n_keep:]) + conv_out = nn.silu(self.conv1d(conv_in)) + + q, k, v = [ + t.reshape(B, S_chunk, h, d) + for t, h, d in zip( + mx.split(conv_out, [self.key_dim, 2 * self.key_dim], -1), + [self.num_k_heads, self.num_k_heads, self.num_v_heads], + [self.head_k_dim, self.head_k_dim, self.head_v_dim], + ) + ] + inv_scale = k.shape[-1] ** -0.5 + q = (inv_scale**2) * mx.fast.rms_norm(q, None, 1e-6) + k = inv_scale * mx.fast.rms_norm(k, None, 1e-6) + + out, new_ssm_state = gated_delta_update( + q, + k, + v, + a_chunk, + b_chunk, + self.A_log, + self.dt_bias, + ssm_state, + ssm_mask, + use_kernel=not 
self.training, + ) + return out, new_conv_state, new_ssm_state + def __call__( self, inputs: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None, + n_confirmed: int = 0, ) -> mx.array: B, S, _ = inputs.shape @@ -152,49 +203,44 @@ def __call__( (B, self.conv_kernel_size - 1, self.conv_dim), dtype=inputs.dtype, ) + ssm_state = cache[1] if cache else None if mask is not None: qkv = mx.where(mask[..., None], qkv, 0) - conv_input = mx.concatenate([conv_state, qkv], axis=1) - if cache is not None: - n_keep = self.conv_kernel_size - 1 - if cache.lengths is not None: - ends = mx.clip(cache.lengths, 0, S) - positions = (ends[:, None] + mx.arange(n_keep))[..., None] - cache[0] = mx.take_along_axis(conv_input, positions, axis=1) - else: - cache[0] = mx.contiguous(conv_input[:, -n_keep:, :]) - conv_out = nn.silu(self.conv1d(conv_input)) - q, k, v = [ - t.reshape(B, S, h, d) - for t, h, d in zip( - mx.split(conv_out, [self.key_dim, 2 * self.key_dim], -1), - [self.num_k_heads, self.num_k_heads, self.num_v_heads], - [self.head_k_dim, self.head_k_dim, self.head_v_dim], + if n_confirmed > 0 and n_confirmed < S: + # Process confirmed and draft tokens separately so we can snapshot the + # SSM/conv state between them for exact rollback on draft rejection. 
+ mask_c = mask[:, :n_confirmed] if mask is not None else None + mask_d = mask[:, n_confirmed:] if mask is not None else None + out_c, conv_c, ssm_c = self._process_chunk( + qkv[:, :n_confirmed], + a[:, :n_confirmed], + b[:, :n_confirmed], + conv_state, + ssm_state, + mask_c, + ) + if cache is not None: + cache.rollback_state = (conv_c, ssm_c) + out_d, conv_f, ssm_f = self._process_chunk( + qkv[:, n_confirmed:], + a[:, n_confirmed:], + b[:, n_confirmed:], + conv_c, + ssm_c, + mask_d, + ) + out = mx.concatenate([out_c, out_d], axis=1) + else: + lengths = cache.lengths if cache is not None else None + out, conv_f, ssm_f = self._process_chunk( + qkv, a, b, conv_state, ssm_state, mask, lengths=lengths ) - ] - - state = cache[1] if cache else None - inv_scale = k.shape[-1] ** -0.5 - q = (inv_scale**2) * mx.fast.rms_norm(q, None, 1e-6) - k = inv_scale * mx.fast.rms_norm(k, None, 1e-6) - - out, state = gated_delta_update( - q, - k, - v, - a, - b, - self.A_log, - self.dt_bias, - state, - mask, - use_kernel=not self.training, - ) if cache is not None: - cache[1] = state + cache[0] = conv_f + cache[1] = ssm_f cache.advance(S) out = self.norm(out, z) @@ -230,9 +276,12 @@ def __call__( x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None, + n_confirmed: int = 0, ) -> mx.array: if self.is_linear: - r = self.linear_attn(self.input_layernorm(x), mask, cache) + r = self.linear_attn( + self.input_layernorm(x), mask, cache, n_confirmed=n_confirmed + ) else: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -240,6 +289,69 @@ def __call__( return out +class MTPDecoderLayer(nn.Module): + """Full-attention-only transformer layer for the MTP head (no GatedDeltaNet).""" + + def __init__(self, args: TextModelArgs): + super().__init__() + self.self_attn = Attention(args) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + if 
args.num_experts > 0: + self.mlp = SparseMoeBlock(args) + else: + self.mlp = MLP(args.hidden_size, args.intermediate_size) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + return h + self.mlp(self.post_attention_layernorm(h)) + + +class MTPModule(nn.Module): + """Multi-Token Prediction head (Qwen3.5 native speculative decoding). + + Predicts token t+2 from the backbone hidden state h_t and the sampled + token t+1, using a shared lm_head with the backbone. + """ + + def __init__(self, args: TextModelArgs): + super().__init__() + self.pre_fc_norm_hidden = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.pre_fc_norm_embedding = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.fc = nn.Linear(args.hidden_size * 2, args.hidden_size, bias=False) + self.layers = [MTPDecoderLayer(args) for _ in range(args.mtp_num_hidden_layers)] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + hidden_states: mx.array, + next_token_ids: mx.array, + embed_tokens: nn.Embedding, + cache: Optional[Any] = None, + ) -> mx.array: + embeds = embed_tokens(next_token_ids) # (B, 1, H) + e = self.pre_fc_norm_embedding(embeds) + h = self.pre_fc_norm_hidden(hidden_states) + fused = self.fc(mx.concatenate([e, h], axis=-1)) # (B, 1, H) + + if cache is None: + cache = [None] * len(self.layers) + + mask = create_attention_mask(fused, cache[0]) + for layer, c in zip(self.layers, cache): + fused = layer(fused, mask, c) + + return self.norm(fused) # (B, 1, H) + + class Qwen3_5TextModel(nn.Module): def __init__(self, args: TextModelArgs): super().__init__() @@ -256,6 +368,7 @@ def __call__( inputs: mx.array, cache: Optional[Any] = None, input_embeddings: Optional[mx.array] = None, + n_confirmed: int = 0, ) -> mx.array: if input_embeddings is not None: hidden_states = input_embeddings @@ -270,9 +383,11 @@ 
def __call__( for layer, c in zip(self.layers, cache): mask = ssm_mask if layer.is_linear else fa_mask - hidden_states = layer(hidden_states, mask=mask, cache=c) + hidden_states = layer( + hidden_states, mask=mask, cache=c, n_confirmed=n_confirmed + ) - return self.norm(hidden_states) + return hidden_states class TextModel(nn.Module): @@ -283,20 +398,55 @@ def __init__(self, args: TextModelArgs): self.model = Qwen3_5TextModel(args) if not args.tie_word_embeddings: self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + if args.mtp_num_hidden_layers > 0: + self.mtp = MTPModule(args) def __call__( self, inputs: mx.array, cache: Optional[Any] = None, input_embeddings: Optional[mx.array] = None, + return_hidden: bool = False, + n_confirmed: int = 0, ) -> mx.array: - out = self.model(inputs, cache, input_embeddings=input_embeddings) + hidden = self.model( + inputs, cache, input_embeddings=input_embeddings, n_confirmed=n_confirmed + ) + normed = self.model.norm(hidden) if self.args.tie_word_embeddings: - out = self.model.embed_tokens.as_linear(out) + out = self.model.embed_tokens.as_linear(normed) else: - out = self.lm_head(out) + out = self.lm_head(normed) + if return_hidden: + return out, hidden # pre-norm hidden for MTP head return out + def mtp_forward( + self, + hidden_states: mx.array, + next_token_ids: mx.array, + mtp_cache: Any, + ) -> mx.array: + """Run the MTP head and apply the shared lm_head. + + Args: + hidden_states: Backbone pre-norm hidden state (B, N, H). N=1 during decode, N>1 during prompt prefill. + next_token_ids: Next token ids, shape (B, N). + mtp_cache: KVCache entries for the MTP transformer layers. + + Returns: + logits of shape (B, N, vocab_size). 
+ """ + mtp_out = self.mtp( + hidden_states, + next_token_ids, + self.model.embed_tokens, + mtp_cache, + ) + if self.args.tie_word_embeddings: + return self.model.embed_tokens.as_linear(mtp_out) + return self.lm_head(mtp_out) + @property def layers(self): return self.model.layers @@ -304,13 +454,28 @@ def layers(self): def make_cache(self): return [ArraysCache(size=2) if l.is_linear else KVCache() for l in self.layers] + def make_mtp_cache(self): + """Return a fresh list of KVCache entries for the MTP layer(s).""" + if hasattr(self, "mtp"): + return [KVCache() for _ in self.mtp.layers] + return [] + def sanitize(self, weights): - has_mtp_weights = any("mtp." in k for k in weights) has_unsanitized_conv1d = any( "conv1d.weight" in k and v.shape[-1] != 1 for k, v in weights.items() ) - should_shift_norm_weights = has_mtp_weights or has_unsanitized_conv1d - weights = {k: v for k, v in weights.items() if "mtp." not in k} + # Norm weights need a +1 shift only in raw HF checkpoints (detected via + # unsanitized conv1d). Already-converted MLX models must not be shifted + # again, even when they contain MTP weights. + should_shift_norm_weights = has_unsanitized_conv1d + # Keep MTP weights if this model has an MTP head; drop them otherwise + if not hasattr(self, "mtp"): + weights = {k: v for k, v in weights.items() if "mtp." not in k} + elif not any("mtp." in k for k in weights): + raise ValueError( + "Config specifies mtp_num_hidden_layers > 0 but the model weights " + "contain no MTP parameters. Set mtp_num_hidden_layers=0 to disable MTP." 
+ ) if self.args.tie_word_embeddings: weights.pop("lm_head.weight", None) @@ -321,6 +486,10 @@ def sanitize(self, weights): "model.norm.weight", ".q_norm.weight", ".k_norm.weight", + # MTP-specific norms (not covered by the patterns above) + ".pre_fc_norm_hidden.weight", + ".pre_fc_norm_embedding.weight", + "mtp.norm.weight", ) for k, v in weights.items(): if "conv1d.weight" in k and v.shape[-1] != 1: @@ -332,14 +501,13 @@ def sanitize(self, weights): @property def quant_predicate(self): - if self.args.num_experts <= 0: - return None - def predicate(path, _): if path.endswith("mlp.gate") or path.endswith("shared_expert_gate"): return {"group_size": 64, "bits": 8} return True + if self.args.num_experts <= 0 and self.args.mtp_num_hidden_layers <= 0: + return None return predicate @property @@ -376,9 +544,15 @@ def __call__( inputs: mx.array, cache=None, input_embeddings: Optional[mx.array] = None, + return_hidden: bool = False, + n_confirmed: int = 0, ): return self.language_model( - inputs, cache=cache, input_embeddings=input_embeddings + inputs, + cache=cache, + input_embeddings=input_embeddings, + return_hidden=return_hidden, + n_confirmed=n_confirmed, ) def sanitize(self, weights): @@ -515,6 +689,19 @@ def _repeat(p): layer.mlp.switch_mlp.up_proj, "all-to-sharded", group=group ) + def mtp_forward( + self, + hidden_states: mx.array, + next_token_ids: mx.array, + mtp_cache: Any, + ) -> mx.array: + """Delegate to language_model.mtp_forward. 
See TextModel.mtp_forward.""" + return self.language_model.mtp_forward(hidden_states, next_token_ids, mtp_cache) + + def make_mtp_cache(self): + """Return fresh KVCache entries for the MTP layer(s).""" + return self.language_model.make_mtp_cache() + @property def layers(self): return self.language_model.model.layers diff --git a/mlx_lm/models/qwen3_5_moe.py b/mlx_lm/models/qwen3_5_moe.py index 53ab8530e..98fe47a19 100644 --- a/mlx_lm/models/qwen3_5_moe.py +++ b/mlx_lm/models/qwen3_5_moe.py @@ -2,6 +2,8 @@ from dataclasses import dataclass +import mlx.core as mx + from .base import BaseModelArgs from .qwen3_5 import Model as Qwen3_5Model @@ -18,6 +20,31 @@ def from_dict(cls, params): return super().from_dict(params) +def _unfuse_experts(weights, prefix): + """Split fused gate_up_proj into per-projection switch_mlp weights (Qwen3.6 format).""" + gate_up_key = f"{prefix}.experts.gate_up_proj" + if gate_up_key not in weights: + return + gate_up = weights.pop(gate_up_key) + mid = gate_up.shape[-2] // 2 + weights[f"{prefix}.switch_mlp.gate_proj.weight"] = gate_up[..., :mid, :] + weights[f"{prefix}.switch_mlp.up_proj.weight"] = gate_up[..., mid:, :] + weights[f"{prefix}.switch_mlp.down_proj.weight"] = weights.pop( + f"{prefix}.experts.down_proj" + ) + + +def _stack_per_expert(weights, prefix, num_experts): + """Stack per-expert weights into switch_mlp format (Qwen3.5 format).""" + for n in ("gate_proj", "up_proj", "down_proj"): + weights[f"{prefix}.switch_mlp.{n}.weight"] = mx.stack( + [ + weights.pop(f"{prefix}.experts.{e}.{n}.weight") + for e in range(num_experts) + ] + ) + + class Model(Qwen3_5Model): def sanitize(self, weights): @@ -27,26 +54,27 @@ def sanitize(self, weights): continue if key.startswith("model.language_model"): key = key.replace("model.language_model", "language_model.model") - elif key.startswith("language_model."): - pass - else: + elif not key.startswith("language_model."): key = "language_model." 
+ key new_weights[key] = value + # Backbone MoE layers always use fused gate_up_proj (both Qwen3.5 and Qwen3.6). for l in range(self.language_model.args.num_hidden_layers): - prefix = f"language_model.model.layers.{l}.mlp" - gate_up_key = f"{prefix}.experts.gate_up_proj" - if gate_up_key in new_weights: - gate_up = new_weights.pop(gate_up_key) - mid = gate_up.shape[-2] // 2 - new_weights[f"{prefix}.switch_mlp.gate_proj.weight"] = gate_up[ - ..., :mid, : - ] - new_weights[f"{prefix}.switch_mlp.up_proj.weight"] = gate_up[ - ..., mid:, : - ] - new_weights[f"{prefix}.switch_mlp.down_proj.weight"] = new_weights.pop( - f"{prefix}.experts.down_proj" - ) + _unfuse_experts(new_weights, f"language_model.model.layers.{l}.mlp") + + # MTP layers: fused format (Qwen3.6) or per-expert format (Qwen3.5). + # Detect format once from the first layer and apply uniformly. + mtp_num = getattr(self.language_model.args, "mtp_num_hidden_layers", 0) + if mtp_num > 0: + num_experts = self.language_model.args.num_experts + mtp_is_fused = ( + "language_model.mtp.layers.0.mlp.experts.gate_up_proj" in new_weights + ) + for layer_idx in range(mtp_num): + prefix = f"language_model.mtp.layers.{layer_idx}.mlp" + if mtp_is_fused: + _unfuse_experts(new_weights, prefix) + else: + _stack_per_expert(new_weights, prefix, num_experts) return self.language_model.sanitize(new_weights) diff --git a/mlx_lm/sample_utils.py b/mlx_lm/sample_utils.py index 05a45fc60..dd6f18d5a 100644 --- a/mlx_lm/sample_utils.py +++ b/mlx_lm/sample_utils.py @@ -2,11 +2,42 @@ import math from functools import partial -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, Tuple import mlx.core as mx +def make_sampler_chain( + top_p: float = 0.0, + top_k: int = 0, + min_p: float = 0.0, + min_tokens_to_keep: int = 1, + xtc_probability: float = 0.0, + xtc_threshold: float = 0.0, + xtc_special_tokens: List[int] = [], +) -> Tuple[List[Callable[[mx.array], mx.array]], Optional[List]]: + 
"""Return (filter_chain, xtc_cell) for use in mtp_generate_step. + + xtc_cell is a mutable [uniform_draw] slot; set it before running the chain + to share the XTC boolean across draft and verify steps. None when XTC is off. + """ + _xtc_cell: Optional[List] = [None] if xtc_probability > 0.0 else None + chain: List[Callable[[mx.array], mx.array]] = [] + if top_p > 0 and top_p < 1.0: + chain.append(lambda x: apply_top_p(x, top_p)) + if min_p != 0.0: + chain.append(lambda x: apply_min_p(x, min_p, min_tokens_to_keep)) + if xtc_probability > 0.0: + chain.append( + lambda x: apply_xtc( + x, xtc_probability, xtc_threshold, xtc_special_tokens, _xtc_cell[0] + ) + ) + if top_k > 0: + chain.append(lambda x: apply_top_k(x, top_k)) + return chain, _xtc_cell + + def make_sampler( temp: float = 0.0, top_p: float = 0.0, @@ -47,17 +78,15 @@ def make_sampler( return lambda x: mx.argmax(x, axis=-1) # Create sampler chain - sampling_methods = [] - if top_p > 0 and top_p < 1.0: - sampling_methods.append(lambda x: apply_top_p(x, top_p)) - if min_p != 0.0: - sampling_methods.append(lambda x: apply_min_p(x, min_p, min_tokens_to_keep)) - if xtc_probability > 0.0: - sampling_methods.append( - lambda x: apply_xtc(x, xtc_probability, xtc_threshold, xtc_special_tokens) - ) - if top_k > 0: - sampling_methods.append(lambda x: apply_top_k(x, top_k)) + sampling_methods, _ = make_sampler_chain( + top_p, + top_k, + min_p, + min_tokens_to_keep, + xtc_probability, + xtc_threshold, + xtc_special_tokens, + ) # Apply the sampling methods def sampler(logprobs): @@ -243,6 +272,7 @@ def apply_xtc( xtc_probability: float, xtc_threshold: float, xtc_special_tokens: List[int], + p_draw: Optional[mx.array] = None, ) -> mx.array: """ Apply XTC sampling to the logits. @@ -252,6 +282,9 @@ def apply_xtc( xtc_probability (float): Probability of XTC sampling to happen for each token xtc_threshold (float): The threshold the probs need to reach for being sampled. 
special_tokens_ids (list(int)): List of special tokens IDs to be excluded from XTC sampling. + p_draw (mx.array, optional): Pre-drawn uniform; if None, draws fresh. + Pass the same draw to draft and verify steps to share the XTC + apply/skip decision across both forward passes. """ if not (0 <= xtc_threshold <= 0.5): raise ValueError( @@ -267,8 +300,9 @@ def apply_xtc( if xtc_special_tokens: mask[..., xtc_special_tokens] = False + draw = mx.random.uniform(0, 1) if p_draw is None else p_draw return mx.where( - mx.random.uniform(0, 1) > xtc_probability, + draw > xtc_probability, logits, mx.where(mask, -mx.inf, logits), ) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index ce8d95817..b1d6214ee 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -396,6 +396,10 @@ def load(self, model_path, adapter_path=None, draft_model_path=None): return self.model, self.tokenizer +def _xtc_special_tokens(tokenizer): + return tokenizer.encode("\n") + list(tokenizer.eos_token_ids) + + def _make_sampler(args, tokenizer): return make_sampler( args.sampling.temperature, @@ -404,10 +408,7 @@ def _make_sampler(args, tokenizer): min_p=args.sampling.min_p, xtc_probability=args.sampling.xtc_probability, xtc_threshold=args.sampling.xtc_threshold, - xtc_special_tokens=[ - tokenizer.eos_token_id, - tokenizer.encode("\n"), - ], + xtc_special_tokens=_xtc_special_tokens(tokenizer), ) @@ -810,7 +811,14 @@ def get_next_request(timeout=None): rqueue.put(e) continue - if not self._is_batchable(args): + # Prefer single-sequence MTP when the queue is empty; + # fall back to BatchGenerator when requests are queued. 
+ mtp_active = getattr(self.cli_args, "mtp", False) and hasattr( + model, "mtp_forward" + ) + if not self._is_batchable(args) or ( + mtp_active and self.requests.empty() + ): self._serve_single((rqueue, request, args)) continue @@ -985,6 +993,14 @@ def progress(tokens_processed, tokens_total): num_draft_tokens=args.num_draft_tokens, prompt_progress_callback=progress, prefill_step_size=self.cli_args.prefill_step_size, + mtp=getattr(self.cli_args, "mtp", False), + temp=args.sampling.temperature, + top_p=args.sampling.top_p, + top_k=args.sampling.top_k, + min_p=args.sampling.min_p, + xtc_probability=args.sampling.xtc_probability, + xtc_threshold=args.sampling.xtc_threshold, + xtc_special_tokens=_xtc_special_tokens(tokenizer), ): finish_reason = gen.finish_reason sm_state, match_sequence, current_state = sm.match(sm_state, gen.token) @@ -1884,6 +1900,12 @@ def main(): action="store_true", help="Use pipelining instead of tensor parallelism", ) + parser.add_argument( + "--mtp", + action="store_true", + help="Use native Multi-Token Prediction for speculative decoding " + "(requires a model with an MTP head, e.g. 
Qwen3.5).", + ) args = parser.parse_args() if mx.metal.is_available(): wired_limit = mx.device_info()["max_recommended_working_set_size"] diff --git a/tests/test_mtp.py b/tests/test_mtp.py new file mode 100644 index 000000000..9996dddf9 --- /dev/null +++ b/tests/test_mtp.py @@ -0,0 +1,291 @@ +import importlib +import unittest + +import mlx.core as mx + +from mlx_lm.generate import generate_step, mtp_generate_step +from mlx_lm.models.cache import make_prompt_cache + + +def _make_qwen3_5_mtp_model(): + """Create a tiny Qwen3.5 model with an MTP head for testing.""" + module = importlib.import_module("mlx_lm.models.qwen3_5") + args = module.ModelArgs.from_dict( + { + "model_type": "qwen3_5", + "text_config": { + "model_type": "qwen3_5", + "hidden_size": 64, + "intermediate_size": 128, + "num_hidden_layers": 4, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "vocab_size": 256, + "linear_num_value_heads": 2, + "linear_num_key_heads": 2, + "linear_key_head_dim": 16, + "linear_value_head_dim": 16, + "linear_conv_kernel_dim": 3, + "full_attention_interval": 2, + "tie_word_embeddings": True, + "rms_norm_eps": 1e-5, + "head_dim": 32, + "rope_theta": 1000.0, + "partial_rotary_factor": 0.5, + "max_position_embeddings": 128, + "mtp_num_hidden_layers": 1, + }, + } + ) + model = module.Model(args) + model.set_dtype(mx.float32) + mx.eval(model.parameters()) + return model + + +class TestMTP(unittest.TestCase): + """Tests for native MTP (Multi-Token Prediction) speculative decoding. + + Uses a tiny synthetic Qwen3.5 model (4 layers, hidden=64, vocab=256) + with mtp_num_hidden_layers=1 and full_attention_interval=2, giving a + mix of GatedDeltaNet (SSM) and full-attention layers. 
+ + Not tested here (would require a real tokenizer loaded from HF): + - stream_generate() with mtp=True/False flag dispatch + - Server integration (--mtp flag, is_batchable) + """ + + @classmethod + def setUpClass(cls): + cls.model = _make_qwen3_5_mtp_model() + + def test_mtp_module_exists(self): + """Model with mtp_num_hidden_layers=1 should have MTP head.""" + self.assertTrue(hasattr(self.model, "mtp_forward")) + self.assertTrue(hasattr(self.model, "make_mtp_cache")) + lm = self.model.language_model + self.assertTrue(hasattr(lm, "mtp")) + self.assertEqual(len(lm.mtp.layers), 1) + + def test_make_mtp_cache(self): + """make_mtp_cache should return one KVCache per MTP layer.""" + mtp_cache = self.model.make_mtp_cache() + self.assertEqual(len(mtp_cache), 1) + self.assertTrue(mtp_cache[0].is_trimmable()) + + def test_return_hidden(self): + """return_hidden=True should return (logits, hidden) with correct shapes.""" + inputs = mx.array([[0, 1, 2]]) + cache = make_prompt_cache(self.model) + out, hidden = self.model(inputs, cache=cache, return_hidden=True) + self.assertEqual(out.shape, (1, 3, 256)) + self.assertEqual(hidden.shape, (1, 3, 64)) + + def test_mtp_forward_shape(self): + """mtp_forward should produce logits of shape (B, 1, vocab).""" + hidden = mx.random.normal((1, 1, 64)) + next_ids = mx.array([[5]]) + mtp_cache = self.model.make_mtp_cache() + logits = self.model.mtp_forward(hidden, next_ids, mtp_cache) + self.assertEqual(logits.shape, (1, 1, 256)) + + def test_hidden_is_pre_norm(self): + """Hidden states returned with return_hidden should be pre-norm. + + This verifies the fix for double normalization: the backbone returns + pre-norm hidden states, and the final norm is applied only before + lm_head (not before the MTP head). 
+ """ + lm = self.model.language_model + inputs = mx.array([[0, 1, 2]]) + cache = make_prompt_cache(self.model) + + _, hidden = lm(inputs, cache=cache, return_hidden=True) + + # Apply the final norm manually and check it changes the values. + normed = lm.model.norm(hidden) + self.assertFalse(mx.allclose(hidden, normed, atol=1e-5).item()) + + def test_mtp_generate_identity(self): + """mtp_generate_step should produce the same greedy tokens as generate_step. + + This is the most important correctness test: it proves that the + draft/verify loop, SSM state rollback on rejection, and MTP cache + management are all correct. Any bug in these would cause the MTP + path to diverge from standard generation. + """ + prompt = mx.array([0, 1, 2, 3], dtype=mx.uint32) + n_tokens = 10 + + # Standard generation, greedy (default sampler is argmax). + std_cache = make_prompt_cache(self.model) + std_tokens = [] + for i, (tok, _) in enumerate( + generate_step(prompt, self.model, prompt_cache=std_cache) + ): + std_tokens.append(int(tok)) + if i + 1 >= n_tokens: + break + + # MTP generation, greedy (default temp=0.0 uses exact-match acceptance). + mtp_tokens = [] + for tok, _, _ in mtp_generate_step(prompt, self.model, max_tokens=n_tokens): + mtp_tokens.append(int(tok)) + if len(mtp_tokens) >= n_tokens: + break + + self.assertEqual( + std_tokens, + mtp_tokens, + f"Token mismatch: std={std_tokens}, mtp={mtp_tokens}", + ) + + def test_mtp_probabilistic_acceptance_completes(self): + """mtp_generate_step should complete without errors with a stochastic sampler. + + Exercises the probabilistic acceptance path: min(1, p_target / p_draft), + both with bare temp (no filters) and with top_k applied.
+ """ + prompt = mx.array([0, 1, 2, 3], dtype=mx.uint32) + n_tokens = 10 + + for kwargs in [ + {"temp": 0.7}, + {"temp": 0.7, "top_k": 10}, + ]: + tokens = [] + for tok, _, _ in mtp_generate_step( + prompt, self.model, max_tokens=n_tokens, **kwargs + ): + tokens.append(int(tok)) + if len(tokens) >= n_tokens: + break + self.assertEqual(len(tokens), n_tokens, f"kwargs={kwargs}") + + def test_mtp_generate_identity_with_logits_processor(self): + """mtp_generate_step must produce the same greedy tokens as generate_step + when a context-sensitive stateless processor is applied. + + A processor that boosts (tokens[-1] + 1) % vocab biases sampling based on + the last token. Incorrect prev_tokens management in the verify pass would + cause the bonus token or the token after a rejection to be sampled with + the wrong bias, producing a sequence that diverges from serial generation. + """ + prompt = mx.array([0, 1, 2, 3], dtype=mx.uint32) + n_tokens = 10 + + def context_processor(tokens, logits): + if tokens is None or tokens.size == 0: + return logits + target = (int(tokens[-1].item()) + 1) % logits.shape[-1] + # 1D boost broadcasts correctly for both (vocab,) and (1, vocab) logits. 
+ boost = mx.zeros(logits.shape[-1]) + return logits + boost.at[target].add(10.0) + + std_cache = make_prompt_cache(self.model) + std_tokens = [] + for i, (tok, _) in enumerate( + generate_step( + prompt, + self.model, + prompt_cache=std_cache, + logits_processors=[context_processor], + ) + ): + std_tokens.append(int(tok)) + if i + 1 >= n_tokens: + break + + mtp_tokens = [] + for tok, _, _ in mtp_generate_step( + prompt, + self.model, + max_tokens=n_tokens, + logits_processors=[context_processor], + ): + mtp_tokens.append(int(tok)) + if len(mtp_tokens) >= n_tokens: + break + + self.assertEqual(std_tokens, mtp_tokens) + + def test_mtp_processor_prev_tokens_correct_at_draft_step(self): + """The processor must see the just-sampled backbone token as tokens[-1] + when the MTP head runs, not the preceding input token. + + A forcing processor logs tokens[-1] on every call. When tokens[-1] equals + the last prompt token (3) it applies a large boost to token 4, guaranteeing + the backbone samples token 4 regardless of model weights. The second + processor call comes from the MTP head: if the token context is correct it + sees 4; if stale it sees 3 again. + """ + # Last prompt token is 3; the forcing processor boosts token 4 when it + # sees 3, so the backbone deterministically samples T0 = 4 regardless of weights. + prompt = mx.array([0, 1, 2, 3], dtype=mx.uint32) + + logged: list[int] = [] + + def forcing_processor(tokens, logits): + if tokens is not None and tokens.size > 0: + last = int(tokens[-1].item()) + logged.append(last) + if last == 3: + boost = mx.zeros(logits.shape[-1]) + return logits + boost.at[4].add(1000.0) + return logits + + for _tok, _, _ in mtp_generate_step( + prompt, + self.model, + max_tokens=2, + logits_processors=[forcing_processor], + ): + pass + + # First call (backbone): context is the last prompt token. 
+ self.assertGreaterEqual(len(logged), 2) + self.assertEqual(logged[0], 3) + # Second call (MTP head): context must be T0 = 4, not the prompt token. + self.assertEqual(logged[1], 4) + + def _collect_rejection_tokens(self, n_runs=60, **kwargs): + """Run mtp_generate_step n_runs times with max_tokens=1 and return all + rejection tokens (from_draft=False).""" + prompt = mx.array([0, 1, 2, 3], dtype=mx.uint32) + rejection_tokens: list[int] = [] + for _ in range(n_runs): + for tok, _, from_draft in mtp_generate_step( + prompt, self.model, max_tokens=1, **kwargs + ): + if not from_draft: + rejection_tokens.append(int(tok)) + return rejection_tokens + + def _assert_residual_varies(self, rejection_tokens, label=""): + self.assertGreaterEqual( + len(rejection_tokens), + 5, + f"{label}Too few rejection events observed; increase n_runs", + ) + self.assertGreater( + len(set(rejection_tokens)), + 1, + f"{label}Rejection tokens are always identical, argmax bug likely present", + ) + + def test_mtp_rejection_residual_sampling(self): + """On rejection at temp>0, the emitted token must be sampled from the + residual distribution max(p_target - p_draft, 0) / Z, not the backbone + argmax. The argmax is deterministic for a fixed model state, so rejection + tokens would always be identical. Residual sampling produces a + distribution, so tokens must vary across runs. + + Using max_tokens=1 yields exactly one token per run: the accepted draft + (from_draft=True) or the rejection token (from_draft=False). This avoids + conflating rejection tokens with bonus tokens (also from_draft=False). + """ + self._assert_residual_varies(self._collect_rejection_tokens(temp=1.0)) + + +if __name__ == "__main__": + unittest.main()