From d203eed5ba7ad96b65d1979bbf5e3a39a8567be4 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Thu, 12 Mar 2026 18:03:46 +0100 Subject: [PATCH 01/29] feat: native MTP speculative decoding for Qwen3.5 Add mtp_generate_step() in generate.py and MTPModule/MTPDecoderLayer in qwen3_5.py. Fixes norm weight shift for MTP-specific RMSNorm weights. Known limitation: SSM state contamination on rejection (GatedDeltaNet layers not trimmable). --- mlx_lm/generate.py | 196 +++++++++++++++++++++++++++++++++++++-- mlx_lm/models/qwen3_5.py | 153 ++++++++++++++++++++++++++++-- 2 files changed, 335 insertions(+), 14 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 3573b2640..47a8c8a3f 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -654,6 +654,183 @@ def _draft_generate(y, num_draft): _rewind_cache(num_draft, n) +def mtp_generate_step( + prompt: mx.array, + model: nn.Module, + *, + max_tokens: int = 256, + sampler: Optional[Callable[[mx.array], mx.array]] = None, + logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, + prompt_cache: Optional[Any] = None, + prefill_step_size: int = 512, + kv_bits: Optional[int] = None, + kv_group_size: int = 64, + quantized_kv_start: int = 0, +) -> Generator[Tuple[mx.array, mx.array, bool], None, None]: + """A generator using the model's native MTP head for speculative decoding. + + Produces up to 2 tokens per forward pass: + - 1 backbone token (always accepted) + - 1 MTP draft token (accepted if the backbone agrees on the next step) + + The model must expose ``mtp_forward(hidden, next_tok, mtp_cache)`` and + support ``return_hidden=True`` in its ``__call__``. + + Yields: + Tuple[mx.array, mx.array, bool]: token, log-probabilities, from_draft. + """ + y = prompt.astype(mx.uint32) + prev_tokens = None + + if prompt_cache is None: + model_cache = cache.make_prompt_cache(model) + mtp_cache = model.make_mtp_cache() + else: + # When a pre-built cache is provided, split at backbone length + n_main = len(model.layers) + model_cache = prompt_cache[:n_main] + mtp_cache = prompt_cache[n_main:] + + sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) + + quantize_cache_fn = functools.partial( + maybe_quantize_kv_cache, + quantized_kv_start=quantized_kv_start, + kv_group_size=kv_group_size, + kv_bits=kv_bits, + ) + + def _process_and_sample(tokens, logits): + if logits_processors: + for processor in logits_processors: + logits = processor(tokens, logits) + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + tok = sampler(logprobs) + return tok, logprobs + + def _step_backbone(y, n_predict=1): + """One backbone forward pass. Returns (tokens, logprobs, hidden).""" + with mx.stream(generation_stream): + logits, hidden = model(y[None], cache=model_cache, return_hidden=True) + logits = logits[:, -n_predict:, :] + quantize_cache_fn(model_cache) + nonlocal prev_tokens + toks, lps = [], [] + y_ctx = y if n_predict == 1 else y[: -(n_predict - 1)] + for i in range(n_predict): + if logits_processors: + prev_tokens = ( + mx.concatenate([prev_tokens, y_ctx]) + if prev_tokens is not None + else y_ctx + ) + tok, lp = _process_and_sample(prev_tokens, logits[:, i, :].squeeze(0)) + toks.append(tok) + lps.append(lp) + return mx.stack(toks), mx.stack(lps), hidden + + def _step_mtp(hidden_last, main_tok): + """Run MTP head. Returns (draft_token, draft_logprobs).""" + # hidden_last: (1, 1, H), main_tok: 0-d or scalar + next_ids = main_tok.reshape(1, 1) + with mx.stream(generation_stream): + mtp_logits = model.mtp_forward(hidden_last, next_ids, mtp_cache) + quantize_cache_fn(mtp_cache) + mtp_logits = mtp_logits[:, -1, :].squeeze(0) + draft_tok, draft_lp = _process_and_sample(prev_tokens, mtp_logits) + return draft_tok, draft_lp + + def _prefill(y): + while y.size > prefill_step_size: + model(y[:prefill_step_size][None], cache=model_cache) + quantize_cache_fn(model_cache) + mx.eval([c.state for c in model_cache if hasattr(c, "state")]) + y = y[prefill_step_size:] + mx.clear_cache() + return y + + with mx.stream(generation_stream): + y = _prefill(y) + + ntoks = 0 + draft_tok = None + draft_lp = None + + try: + while True: + if draft_tok is None: + # No pending draft — run backbone only, then generate first draft + toks, lps, hidden = _step_backbone(y, n_predict=1) + mx.eval(toks) + main_tok = toks[0] + main_lp = lps[0] + + ntoks += 1 + yield main_tok, main_lp, False + if ntoks >= max_tokens: + break + + draft_tok, draft_lp = _step_mtp(hidden[:, -1:, :], main_tok) + mx.eval(draft_tok) + y = mx.array([main_tok.item()], mx.uint32) + else: + # Verify draft: process [y, draft_tok] through backbone together + y_with_draft = mx.concatenate( + [y, mx.array([draft_tok.item()], mx.uint32)] + ) + toks, lps, hidden = _step_backbone(y_with_draft, n_predict=2) + mx.eval(toks, draft_tok) + + verify_pred = toks[0] # backbone prediction after y → verify draft + bonus_tok = toks[1] # backbone prediction after draft_tok + verify_lp = lps[0] + bonus_lp = lps[1] + + if verify_pred.item() == draft_tok.item(): + # Draft accepted + ntoks += 1 + yield draft_tok, draft_lp, True + if ntoks >= max_tokens: + break + + ntoks += 1 + yield bonus_tok, bonus_lp, False + if ntoks >= max_tokens: + break + + # Next draft from MTP at draft_tok's hidden state + draft_tok, draft_lp = _step_mtp(hidden[:, 1:2, :], bonus_tok) + mx.eval(draft_tok) + y = mx.array([bonus_tok.item()], mx.uint32) + else: + # Draft rejected — trim caches. + # + # Qwen3.5 is a hybrid SSM+Attention model: attention layers use + # KVCache (trimmable), SSM layers use ArraysCache (not trimmable). + # trim_prompt_cache() is all-or-nothing, so we trim KV entries + # individually. The SSM state will retain a 1-token contamination + # from the rejected draft, which is empirically negligible compared + # to the sequence length but means output may differ slightly from + # standard generate_step. A correct fix would require exposing + # per-token intermediate SSM states from GatedDeltaNet (future work). + for c in model_cache: + if c.is_trimmable(): + c.trim(1) + cache.trim_prompt_cache(mtp_cache, 1) + + ntoks += 1 + yield verify_pred, verify_lp, False + if ntoks >= max_tokens: + break + + # Next draft from MTP at y's hidden state + draft_tok, draft_lp = _step_mtp(hidden[:, 0:1, :], verify_pred) + mx.eval(draft_tok) + y = mx.array([verify_pred.item()], mx.uint32) + finally: + pass + + def stream_generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], @@ -698,19 +875,24 @@ def stream_generate( kwargs["max_tokens"] = max_tokens - if draft_model is None: + if draft_model is not None: + kwargs.pop("max_kv_size", None) + kwargs.pop("prompt_progress_callback", None) + token_generator = speculative_generate_step( + prompt, model, draft_model, **kwargs + ) + elif hasattr(model, "mtp_forward"): + kwargs.pop("max_kv_size", None) + kwargs.pop("prompt_progress_callback", None) + kwargs.pop("num_draft_tokens", None) + token_generator = mtp_generate_step(prompt, model, **kwargs) + else: kwargs.pop("num_draft_tokens", None) token_generator = generate_step(prompt, model, **kwargs) # from_draft always false for non-speculative generation token_generator = ( (token, logprobs, False) for token, logprobs in token_generator ) - else: - kwargs.pop("max_kv_size", None) - kwargs.pop("prompt_progress_callback", None) - token_generator = speculative_generate_step( - prompt, model, draft_model, **kwargs - ) with wired_limit(model, [generation_stream]): tic = time.perf_counter() for n, (token, logprobs, from_draft) in enumerate(token_generator): diff --git a/mlx_lm/models/qwen3_5.py b/mlx_lm/models/qwen3_5.py index a86710e74..0a48656dd 100644 --- a/mlx_lm/models/qwen3_5.py +++ b/mlx_lm/models/qwen3_5.py @@ -50,6 +50,10 @@ class TextModelArgs(BaseModelArgs): moe_intermediate_size: int = 0 norm_topk_prob: bool = True + # MTP fields + mtp_num_hidden_layers: int = 0 + mtp_use_dedicated_embeddings: bool = False + # Rope parameters rope_parameters: Optional[Dict[str, Union[float, str, bool, List[int]]]] = field( default_factory=lambda: { @@ -240,6 +244,79 @@ def __call__( return out +class MTPDecoderLayer(nn.Module): + """Full-attention-only transformer layer for the MTP head (no GatedDeltaNet).""" + + def __init__(self, args: TextModelArgs): + super().__init__() + self.self_attn = Attention(args) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + if args.num_experts > 0: + self.mlp = SparseMoeBlock(args) + else: + self.mlp = MLP(args.hidden_size, args.intermediate_size) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + return h + self.mlp(self.post_attention_layernorm(h)) + + +class MTPModule(nn.Module): + """Multi-Token Prediction head. + + Predicts the token at position t+2 given: + - h_t : backbone hidden state at the last accepted position t + - t_main: the main model's sampled prediction for t+1 + + Forward: + fused = fc(cat([pre_fc_norm_embedding(embed(t_main)), + pre_fc_norm_hidden(h_t)])) + → MTPDecoderLayer(s) + → norm + → (caller applies lm_head, shared with backbone) + """ + + def __init__(self, args: TextModelArgs): + super().__init__() + self.pre_fc_norm_hidden = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.pre_fc_norm_embedding = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.fc = nn.Linear(args.hidden_size * 2, args.hidden_size, bias=False) + self.layers = [ + MTPDecoderLayer(args) for _ in range(args.mtp_num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + hidden_states: mx.array, + next_token_ids: mx.array, + embed_tokens: nn.Embedding, + cache: Optional[Any] = None, + ) -> mx.array: + # hidden_states : (B, 1, H) — backbone hidden at last accepted position + # next_token_ids: (B, 1) — t_main (main model's prediction for t+1) + embeds = embed_tokens(next_token_ids) # (B, 1, H) + e = self.pre_fc_norm_embedding(embeds) + h = self.pre_fc_norm_hidden(hidden_states) + fused = self.fc(mx.concatenate([e, h], axis=-1)) # (B, 1, H) + + if cache is None: + cache = [None] * len(self.layers) + + mask = create_attention_mask(fused, cache[0]) + for layer, c in zip(self.layers, cache): + fused = layer(fused, mask, c) + + return self.norm(fused) # (B, 1, H) + + class Qwen3_5TextModel(nn.Module): def __init__(self, args: TextModelArgs): super().__init__() @@ -283,20 +360,51 @@ def __init__(self, args: TextModelArgs): self.model = Qwen3_5TextModel(args) if not args.tie_word_embeddings: self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + if args.mtp_num_hidden_layers > 0: + self.mtp = MTPModule(args) def __call__( self, inputs: mx.array, cache: Optional[Any] = None, input_embeddings: Optional[mx.array] = None, + return_hidden: bool = False, ) -> mx.array: - out = self.model(inputs, cache, input_embeddings=input_embeddings) + hidden = self.model(inputs, cache, input_embeddings=input_embeddings) if self.args.tie_word_embeddings: - out = self.model.embed_tokens.as_linear(out) + out = self.model.embed_tokens.as_linear(hidden) else: - out = self.lm_head(out) + out = self.lm_head(hidden) + if return_hidden: + return out, hidden return out + def mtp_forward( + self, + hidden_states: mx.array, + next_token_ids: mx.array, + mtp_cache: Any, + ) -> mx.array: + """Run the MTP head and apply the shared lm_head. + + Args: + hidden_states: (B, 1, H) — backbone hidden state at the last position. + next_token_ids: (B, 1) — sampled main token (t_main). + mtp_cache: list of KVCache entries for the MTP transformer layer(s). + + Returns: + logits: (B, 1, vocab_size) + """ + mtp_out = self.mtp( + hidden_states, + next_token_ids, + self.model.embed_tokens, + mtp_cache, + ) + if self.args.tie_word_embeddings: + return self.model.embed_tokens.as_linear(mtp_out) + return self.lm_head(mtp_out) + @property def layers(self): return self.model.layers @@ -304,13 +412,23 @@ def layers(self): def make_cache(self): return [ArraysCache(size=2) if l.is_linear else KVCache() for l in self.layers] + def make_mtp_cache(self): + """Return a fresh list of KVCache entries for the MTP layer(s).""" + if hasattr(self, "mtp"): + return [KVCache() for _ in self.mtp.layers] + return [] + def sanitize(self, weights): - has_mtp_weights = any("mtp." in k for k in weights) has_unsanitized_conv1d = any( "conv1d.weight" in k and v.shape[-1] != 1 for k, v in weights.items() ) - should_shift_norm_weights = has_mtp_weights or has_unsanitized_conv1d - weights = {k: v for k, v in weights.items() if "mtp." not in k} + # Norm weights need a +1 shift only in raw HF checkpoints (detected via + # unsanitized conv1d). Already-converted MLX models (conv1d fixed) must NOT + # be shifted again — even when they contain MTP weights. + should_shift_norm_weights = has_unsanitized_conv1d + # Keep MTP weights if this model has an MTP head; drop them otherwise + if not hasattr(self, "mtp"): + weights = {k: v for k, v in weights.items() if "mtp." not in k} if self.args.tie_word_embeddings: weights.pop("lm_head.weight", None) @@ -321,6 +439,10 @@ def sanitize(self, weights): "model.norm.weight", ".q_norm.weight", ".k_norm.weight", + # MTP-specific norms (not covered by the patterns above) + ".pre_fc_norm_hidden.weight", + ".pre_fc_norm_embedding.weight", + "mtp.norm.weight", ) for k, v in weights.items(): if "conv1d.weight" in k and v.shape[-1] != 1: @@ -376,9 +498,13 @@ def __call__( inputs: mx.array, cache=None, input_embeddings: Optional[mx.array] = None, + return_hidden: bool = False, ): return self.language_model( - inputs, cache=cache, input_embeddings=input_embeddings + inputs, + cache=cache, + input_embeddings=input_embeddings, + return_hidden=return_hidden, ) def sanitize(self, weights): @@ -515,6 +641,19 @@ def _repeat(p): layer.mlp.switch_mlp.up_proj, "all-to-sharded", group=group ) + def mtp_forward( + self, + hidden_states: mx.array, + next_token_ids: mx.array, + mtp_cache: Any, + ) -> mx.array: + """Delegate to language_model.mtp_forward. See TextModel.mtp_forward.""" + return self.language_model.mtp_forward(hidden_states, next_token_ids, mtp_cache) + + def make_mtp_cache(self): + """Return fresh KVCache entries for the MTP layer(s).""" + return self.language_model.make_mtp_cache() + @property def layers(self): return self.language_model.model.layers From 651b9458b5e8e1a4490562a741963087b6426cce Mon Sep 17 00:00:00 2001 From: AirRunner Date: Thu, 12 Mar 2026 18:27:29 +0100 Subject: [PATCH 02/29] fix(mtp): eliminate SSM state contamination on draft rejection Extend GatedDeltaNet.__call__ with an n_confirmed parameter that splits the T=2 verification pass into two sub-calls. After processing the confirmed token, the intermediate conv/ssm state is snapshotted into ArraysCache.rollback_state. On rejection, SSM layers restore this snapshot while attention layers trim their KV cache by 1 as before. Acceptance rate ~64% average / ~85% on 100-token run. --- mlx_lm/generate.py | 56 +++++++++------- mlx_lm/models/cache.py | 3 + mlx_lm/models/qwen3_5.py | 139 ++++++++++++++++++++++----------------- 3 files changed, 114 insertions(+), 84 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 47a8c8a3f..ecf332d0a 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -667,17 +667,19 @@ def mtp_generate_step( kv_group_size: int = 64, quantized_kv_start: int = 0, ) -> Generator[Tuple[mx.array, mx.array, bool], None, None]: - """A generator using the model's native MTP head for speculative decoding. + """A generator that uses the model's native MTP head for speculative decoding. - Produces up to 2 tokens per forward pass: - - 1 backbone token (always accepted) - - 1 MTP draft token (accepted if the backbone agrees on the next step) + Each iteration runs one backbone forward pass over the current token and its + pending draft, then one MTP forward pass to propose the next draft. Up to 2 + tokens are emitted per backbone step: one always-accepted backbone token and + one conditionally-accepted draft token. - The model must expose ``mtp_forward(hidden, next_tok, mtp_cache)`` and + The model must implement ``mtp_forward(hidden, next_tok, mtp_cache)`` and support ``return_hidden=True`` in its ``__call__``. Yields: - Tuple[mx.array, mx.array, bool]: token, log-probabilities, from_draft. + Tuple[mx.array, mx.array, bool]: (token, log-probabilities, from_draft). + ``from_draft`` is ``True`` when the token came from the MTP head. """ y = prompt.astype(mx.uint32) prev_tokens = None @@ -708,10 +710,10 @@ def _process_and_sample(tokens, logits): tok = sampler(logprobs) return tok, logprobs - def _step_backbone(y, n_predict=1): - """One backbone forward pass. Returns (tokens, logprobs, hidden).""" + def _step_backbone(y, n_predict=1, n_confirmed=0): + """Run the backbone on ``y`` and return (tokens, logprobs, hidden).""" with mx.stream(generation_stream): - logits, hidden = model(y[None], cache=model_cache, return_hidden=True) + logits, hidden = model(y[None], cache=model_cache, return_hidden=True, n_confirmed=n_confirmed) logits = logits[:, -n_predict:, :] quantize_cache_fn(model_cache) nonlocal prev_tokens @@ -730,8 +732,7 @@ def _step_backbone(y, n_predict=1): return mx.stack(toks), mx.stack(lps), hidden def _step_mtp(hidden_last, main_tok): - """Run MTP head. Returns (draft_token, draft_logprobs).""" - # hidden_last: (1, 1, H), main_tok: 0-d or scalar + """Run the MTP head and return (draft_token, draft_logprobs).""" next_ids = main_tok.reshape(1, 1) with mx.stream(generation_stream): mtp_logits = model.mtp_forward(hidden_last, next_ids, mtp_cache) @@ -774,11 +775,13 @@ def _prefill(y): mx.eval(draft_tok) y = mx.array([main_tok.item()], mx.uint32) else: - # Verify draft: process [y, draft_tok] through backbone together + # Verify draft: run backbone over [y, draft_tok]. + # n_confirmed=1 causes GatedDeltaNet to snapshot its SSM/conv state + # after the confirmed token y, enabling exact rollback on rejection. y_with_draft = mx.concatenate( [y, mx.array([draft_tok.item()], mx.uint32)] ) - toks, lps, hidden = _step_backbone(y_with_draft, n_predict=2) + toks, lps, hidden = _step_backbone(y_with_draft, n_predict=2, n_confirmed=1) mx.eval(toks, draft_tok) verify_pred = toks[0] # backbone prediction after y → verify draft @@ -787,7 +790,11 @@ def _prefill(y): bonus_lp = lps[1] if verify_pred.item() == draft_tok.item(): - # Draft accepted + # Draft accepted — discard rollback snapshots. + for c in model_cache: + if hasattr(c, "rollback_state"): + c.rollback_state = None + ntoks += 1 yield draft_tok, draft_lp, True if ntoks >= max_tokens: @@ -803,18 +810,17 @@ def _prefill(y): mx.eval(draft_tok) y = mx.array([bonus_tok.item()], mx.uint32) else: - # Draft rejected — trim caches. - # - # Qwen3.5 is a hybrid SSM+Attention model: attention layers use - # KVCache (trimmable), SSM layers use ArraysCache (not trimmable). - # trim_prompt_cache() is all-or-nothing, so we trim KV entries - # individually. The SSM state will retain a 1-token contamination - # from the rejected draft, which is empirically negligible compared - # to the sequence length but means output may differ slightly from - # standard generate_step. A correct fix would require exposing - # per-token intermediate SSM states from GatedDeltaNet (future work). + # Draft rejected — roll back all caches to the state after y. + # SSM layers (ArraysCache): restore the conv/ssm snapshot saved + # by GatedDeltaNet after the confirmed token. + # Attention layers (KVCache): trim the draft-token entry. for c in model_cache: - if c.is_trimmable(): + if hasattr(c, "rollback_state") and c.rollback_state is not None: + conv_snap, ssm_snap = c.rollback_state + c[0] = conv_snap + c[1] = ssm_snap + c.rollback_state = None + elif c.is_trimmable(): c.trim(1) cache.trim_prompt_cache(mtp_cache, 1) diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index b84c9d650..745f5ebfe 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -596,6 +596,9 @@ def __new__(cls, *args, **kwargs): instance = super().__new__(cls) instance.left_padding = None instance.lengths = None + # Snapshot of (conv_state, ssm_state) saved after processing confirmed tokens + # in an MTP draft-verification step. Cleared after each step. + instance.rollback_state = None return instance def __init__(self, size, left_padding: Optional[List[int]] = None): diff --git a/mlx_lm/models/qwen3_5.py b/mlx_lm/models/qwen3_5.py index 0a48656dd..d0919b54d 100644 --- a/mlx_lm/models/qwen3_5.py +++ b/mlx_lm/models/qwen3_5.py @@ -52,7 +52,6 @@ class TextModelArgs(BaseModelArgs): # MTP fields mtp_num_hidden_layers: int = 0 - mtp_use_dedicated_embeddings: bool = False # Rope parameters rope_parameters: Optional[Dict[str, Union[float, str, bool, List[int]]]] = field( @@ -133,11 +132,52 @@ def __init__(self, config: TextModelArgs): self.sharding_group = None + def _process_chunk( + self, + qkv_chunk: mx.array, + a_chunk: mx.array, + b_chunk: mx.array, + conv_state: mx.array, + ssm_state: Optional[mx.array], + ssm_mask: Optional[mx.array] = None, + lengths: Optional[mx.array] = None, + ): + B, S_chunk = qkv_chunk.shape[:2] + conv_in = mx.concatenate([conv_state, qkv_chunk], axis=1) + n_keep = self.conv_kernel_size - 1 + if lengths is not None: + ends = mx.clip(lengths, 0, S_chunk) + positions = (ends[:, None] + mx.arange(n_keep))[..., None] + new_conv_state = mx.take_along_axis(conv_in, positions, axis=1) + else: + new_conv_state = mx.contiguous(conv_in[:, -n_keep:]) + conv_out = nn.silu(self.conv1d(conv_in)) + + q, k, v = [ + t.reshape(B, S_chunk, h, d) + for t, h, d in zip( + mx.split(conv_out, [self.key_dim, 2 * self.key_dim], -1), + [self.num_k_heads, self.num_k_heads, self.num_v_heads], + [self.head_k_dim, self.head_k_dim, self.head_v_dim], + ) + ] + inv_scale = k.shape[-1] ** -0.5 + q = (inv_scale**2) * mx.fast.rms_norm(q, None, 1e-6) + k = inv_scale * mx.fast.rms_norm(k, None, 1e-6) + + out, new_ssm_state = gated_delta_update( + q, k, v, a_chunk, b_chunk, + self.A_log, self.dt_bias, ssm_state, ssm_mask, + use_kernel=not self.training, + ) + return out, new_conv_state, new_ssm_state + def __call__( self, inputs: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None, + n_confirmed: int = 0, ) -> mx.array: B, S, _ = inputs.shape @@ -149,56 +189,39 @@ def __call__( b = self.in_proj_b(inputs) a = self.in_proj_a(inputs) - if cache is not None and cache[0] is not None: - conv_state = cache[0] - else: - conv_state = mx.zeros( - (B, self.conv_kernel_size - 1, self.conv_dim), - dtype=inputs.dtype, - ) + conv_state = ( + cache[0] + if cache is not None and cache[0] is not None + else mx.zeros((B, self.conv_kernel_size - 1, self.conv_dim), dtype=inputs.dtype) + ) + ssm_state = cache[1] if cache else None if mask is not None: qkv = mx.where(mask[..., None], qkv, 0) - conv_input = mx.concatenate([conv_state, qkv], axis=1) - if cache is not None: - n_keep = self.conv_kernel_size - 1 - if cache.lengths is not None: - ends = mx.clip(cache.lengths, 0, S) - positions = (ends[:, None] + mx.arange(n_keep))[..., None] - cache[0] = mx.take_along_axis(conv_input, positions, axis=1) - else: - cache[0] = mx.contiguous(conv_input[:, -n_keep:, :]) - conv_out = nn.silu(self.conv1d(conv_input)) - q, k, v = [ - t.reshape(B, S, h, d) - for t, h, d in zip( - mx.split(conv_out, [self.key_dim, 2 * self.key_dim], -1), - [self.num_k_heads, self.num_k_heads, self.num_v_heads], - [self.head_k_dim, self.head_k_dim, self.head_v_dim], + if n_confirmed > 0 and n_confirmed < S: + # Process confirmed and draft tokens separately so we can snapshot the + # SSM/conv state between them for exact rollback on draft rejection. + mask_c = mask[:, :n_confirmed] if mask is not None else None + mask_d = mask[:, n_confirmed:] if mask is not None else None + out_c, conv_c, ssm_c = self._process_chunk( + qkv[:, :n_confirmed], a[:, :n_confirmed], b[:, :n_confirmed], + conv_state, ssm_state, mask_c, ) - ] - - state = cache[1] if cache else None - inv_scale = k.shape[-1] ** -0.5 - q = (inv_scale**2) * mx.fast.rms_norm(q, None, 1e-6) - k = inv_scale * mx.fast.rms_norm(k, None, 1e-6) - - out, state = gated_delta_update( - q, - k, - v, - a, - b, - self.A_log, - self.dt_bias, - state, - mask, - use_kernel=not self.training, - ) + if cache is not None: + cache.rollback_state = (conv_c, ssm_c) + out_d, conv_f, ssm_f = self._process_chunk( + qkv[:, n_confirmed:], a[:, n_confirmed:], b[:, n_confirmed:], + conv_c, ssm_c, mask_d, + ) + out = mx.concatenate([out_c, out_d], axis=1) + else: + lengths = cache.lengths if cache is not None else None + out, conv_f, ssm_f = self._process_chunk(qkv, a, b, conv_state, ssm_state, mask, lengths=lengths) if cache is not None: - cache[1] = state + cache[0] = conv_f + cache[1] = ssm_f cache.advance(S) out = self.norm(out, z) @@ -234,9 +257,10 @@ def __call__( x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None, + n_confirmed: int = 0, ) -> mx.array: if self.is_linear: - r = self.linear_attn(self.input_layernorm(x), mask, cache) + r = self.linear_attn(self.input_layernorm(x), mask, cache, n_confirmed=n_confirmed) else: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -269,18 +293,10 @@ def __call__( class MTPModule(nn.Module): - """Multi-Token Prediction head. - - Predicts the token at position t+2 given: - - h_t : backbone hidden state at the last accepted position t - - t_main: the main model's sampled prediction for t+1 - - Forward: - fused = fc(cat([pre_fc_norm_embedding(embed(t_main)), - pre_fc_norm_hidden(h_t)])) - → MTPDecoderLayer(s) - → norm - → (caller applies lm_head, shared with backbone) + """Multi-Token Prediction head (Qwen3.5 native speculative decoding). + + Predicts token t+2 from the backbone hidden state h_t and the sampled + token t+1, using a shared lm_head with the backbone. """ def __init__(self, args: TextModelArgs): @@ -333,6 +349,7 @@ def __call__( inputs: mx.array, cache: Optional[Any] = None, input_embeddings: Optional[mx.array] = None, + n_confirmed: int = 0, ) -> mx.array: if input_embeddings is not None: hidden_states = input_embeddings @@ -347,7 +364,8 @@ def __call__( for layer, c in zip(self.layers, cache): mask = ssm_mask if layer.is_linear else fa_mask - hidden_states = layer(hidden_states, mask=mask, cache=c) + kw = {"n_confirmed": n_confirmed} if layer.is_linear and n_confirmed > 0 else {} + hidden_states = layer(hidden_states, mask=mask, cache=c, **kw) return self.norm(hidden_states) @@ -369,8 +387,9 @@ def __call__( cache: Optional[Any] = None, input_embeddings: Optional[mx.array] = None, return_hidden: bool = False, + n_confirmed: int = 0, ) -> mx.array: - hidden = self.model(inputs, cache, input_embeddings=input_embeddings) + hidden = self.model(inputs, cache, input_embeddings=input_embeddings, n_confirmed=n_confirmed) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(hidden) else: @@ -499,12 +518,14 @@ def __call__( cache=None, input_embeddings: Optional[mx.array] = None, return_hidden: bool = False, + n_confirmed: int = 0, ): return self.language_model( inputs, cache=cache, input_embeddings=input_embeddings, return_hidden=return_hidden, + n_confirmed=n_confirmed, ) def sanitize(self, weights): From 43f4205b356a09e0ca7128719e51170de1718df5 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Thu, 12 Mar 2026 21:50:57 +0100 Subject: [PATCH 03/29] fix(mtp): server integration (yield types, cache fallback, batching) - Yield token.item() instead of raw mx.array to match generate_step convention (fixes detokenizer crash via stream_generate) - Create MTP cache when prompt_cache lacks MTP entries (server creates backbone-only caches via make_prompt_cache) - Disable batch generation for MTP models (draft/verify loop requires single-sequence processing) Note: batch-aware MTP would need per-sequence accept/reject and SSM rollback within BatchGenerator --- mlx_lm/generate.py | 13 +++++++------ mlx_lm/server.py | 3 +++ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index ecf332d0a..5bc7aa226 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -688,10 +688,11 @@ def mtp_generate_step( model_cache = cache.make_prompt_cache(model) mtp_cache = model.make_mtp_cache() else: - # When a pre-built cache is provided, split at backbone length + # Split a pre-built cache at backbone length. If MTP entries are + # absent (e.g. cache created by make_prompt_cache), create them. n_main = len(model.layers) model_cache = prompt_cache[:n_main] - mtp_cache = prompt_cache[n_main:] + mtp_cache = prompt_cache[n_main:] or model.make_mtp_cache() sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) @@ -767,7 +768,7 @@ def _prefill(y): main_lp = lps[0] ntoks += 1 - yield main_tok, main_lp, False + yield main_tok.item(), main_lp, False if ntoks >= max_tokens: break @@ -796,12 +797,12 @@ def _prefill(y): c.rollback_state = None ntoks += 1 - yield draft_tok, draft_lp, True + yield draft_tok.item(), draft_lp, True if ntoks >= max_tokens: break ntoks += 1 - yield bonus_tok, bonus_lp, False + yield bonus_tok.item(), bonus_lp, False if ntoks >= max_tokens: break @@ -825,7 +826,7 @@ def _prefill(y): cache.trim_prompt_cache(mtp_cache, 1) ntoks += 1 - yield verify_pred, verify_lp, False + yield verify_pred.item(), verify_lp, False if ntoks >= max_tokens: break diff --git a/mlx_lm/server.py b/mlx_lm/server.py index ce8d95817..8a4d1217f 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -378,6 +378,9 @@ def _load(self, model_path, adapter_path=None, draft_model_path=None): self.model = model self.tokenizer = tokenizer self.draft_model = draft_model + # MTP speculative decoding requires single-sequence generation. + if hasattr(model, "mtp_forward"): + is_batchable = False self.is_batchable = is_batchable def load_default(self): From 7449a004614c2a70bccbaeaafa21749e403e5540 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Thu, 12 Mar 2026 22:26:12 +0100 Subject: [PATCH 04/29] fix(mtp): address @janhilgard code review feedback (double-norm, quant_predicate) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Return pre-norm hidden states from Qwen3_5TextModel: apply norm in TextModel before lm_head only (avoiding double normalization (model.norm + pre_fc_norm_hidden). - Exclude mtp.fc from quantization via quant_predicate (the fusion projection (2H→H) stays in bf16 for accuracy). 27B results after reconversion: 80.6% acceptance, 23.3 tok/s on M4 Pro (1.52x). --- mlx_lm/models/qwen3_5.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/mlx_lm/models/qwen3_5.py b/mlx_lm/models/qwen3_5.py index d0919b54d..327078a0f 100644 --- a/mlx_lm/models/qwen3_5.py +++ b/mlx_lm/models/qwen3_5.py @@ -367,7 +367,7 @@ def __call__( kw = {"n_confirmed": n_confirmed} if layer.is_linear and n_confirmed > 0 else {} hidden_states = layer(hidden_states, mask=mask, cache=c, **kw) - return self.norm(hidden_states) + return hidden_states class TextModel(nn.Module): @@ -390,12 +390,13 @@ def __call__( n_confirmed: int = 0, ) -> mx.array: hidden = self.model(inputs, cache, input_embeddings=input_embeddings, n_confirmed=n_confirmed) + normed = self.model.norm(hidden) if self.args.tie_word_embeddings: - out = self.model.embed_tokens.as_linear(hidden) + out = self.model.embed_tokens.as_linear(normed) else: - out = self.lm_head(hidden) + out = self.lm_head(normed) if return_hidden: - return out, hidden + return out, hidden # pre-norm hidden for MTP head return out def mtp_forward( @@ -473,14 +474,16 @@ def sanitize(self, weights): @property def quant_predicate(self): - if self.args.num_experts <= 0: - return None - def predicate(path, _): if path.endswith("mlp.gate") or path.endswith("shared_expert_gate"): return {"group_size": 64, "bits": 8} + # Keep the MTP fusion projection in full precision. + if path.endswith("mtp.fc"): + return False return True + if self.args.num_experts <= 0 and self.args.mtp_num_hidden_layers <= 0: + return None return predicate @property From 71011abec9a919c27c10e2ecfffad16acc0afd8e Mon Sep 17 00:00:00 2001 From: AirRunner Date: Fri, 13 Mar 2026 00:07:52 +0100 Subject: [PATCH 05/29] feat(mtp): add --mtp CLI flag for generate and server Replace auto-detection of MTP head with explicit --mtp flag, consistent with existing --draft-model for speculative decoding. MTP is now opt-in. Without the flag, models with MTP weights use standard generation and batch serving remains fully functional. --- mlx_lm/generate.py | 12 +++++++++++- mlx_lm/server.py | 14 ++++++++++++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 5bc7aa226..3932a6bdb 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -219,6 +219,12 @@ def setup_arg_parser(): help="Number of tokens to draft when using speculative decoding.", default=3, ) + parser.add_argument( + "--mtp", + action="store_true", + help="Use native Multi-Token Prediction for speculative decoding " + "(requires a model with an MTP head, e.g. Qwen3.5).", + ) return parser @@ -844,6 +850,7 @@ def stream_generate( prompt: Union[str, mx.array, List[int]], max_tokens: int = 256, draft_model: Optional[nn.Module] = None, + mtp: bool = False, **kwargs, ) -> Generator[GenerationResponse, None, None]: """ @@ -859,6 +866,8 @@ def stream_generate( draft_model (Optional[nn.Module]): An optional draft model. If provided then speculative decoding is used. The draft model must use the same tokenizer as the main model. Default: ``None``. + mtp (bool): Use native Multi-Token Prediction for speculative + decoding. Requires a model with an MTP head. Default: ``False``. kwargs: The remaining options get passed to :func:`generate_step`. See :func:`generate_step` for more details. @@ -888,7 +897,7 @@ def stream_generate( token_generator = speculative_generate_step( prompt, model, draft_model, **kwargs ) - elif hasattr(model, "mtp_forward"): + elif mtp and hasattr(model, "mtp_forward"): kwargs.pop("max_kv_size", None) kwargs.pop("prompt_progress_callback", None) kwargs.pop("num_draft_tokens", None) @@ -2272,6 +2281,7 @@ def main(): quantized_kv_start=args.quantized_kv_start, draft_model=draft_model, num_draft_tokens=args.num_draft_tokens, + mtp=args.mtp, ) if not args.verbose: print(response) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 8a4d1217f..8ebdab76f 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -378,8 +378,11 @@ def _load(self, model_path, adapter_path=None, draft_model_path=None): self.model = model self.tokenizer = tokenizer self.draft_model = draft_model - # MTP speculative decoding requires single-sequence generation. - if hasattr(model, "mtp_forward"): + # MTP speculative decoding uses single-sequence generation + # (draft/verify loop is incompatible with batch generation). + # TODO: dynamically switch between MTP (1 request) and + # BatchGenerator (>= 2 concurrent requests). + if self.cli_args.mtp and hasattr(model, "mtp_forward"): is_batchable = False self.is_batchable = is_batchable @@ -988,6 +991,7 @@ def progress(tokens_processed, tokens_total): num_draft_tokens=args.num_draft_tokens, prompt_progress_callback=progress, prefill_step_size=self.cli_args.prefill_step_size, + mtp=getattr(self.cli_args, "mtp", False), ): finish_reason = gen.finish_reason sm_state, match_sequence, current_state = sm.match(sm_state, gen.token) @@ -1887,6 +1891,12 @@ def main(): action="store_true", help="Use pipelining instead of tensor parallelism", ) + parser.add_argument( + "--mtp", + action="store_true", + help="Use native Multi-Token Prediction for speculative decoding " + "(requires a model with an MTP head, e.g. Qwen3.5).", + ) args = parser.parse_args() if mx.metal.is_available(): wired_limit = mx.device_info()["max_recommended_working_set_size"] From 44430ccca4b3fee94cee1890538a45b0003d4fbc Mon Sep 17 00:00:00 2001 From: AirRunner Date: Fri, 13 Mar 2026 00:25:53 +0100 Subject: [PATCH 06/29] test(mtp): add unit tests for MTP speculative decoding 8 tests using a tiny synthetic Qwen3.5 model (4 layers, hidden=64) with mtp_num_hidden_layers=1 and hybrid SSM+attention layers. - MTP module instantiation and cache creation - return_hidden shape and pre-norm verification - mtp_forward output shape - quant_predicate excludes mtp.fc - Token identity: mtp_generate_step == generate_step (greedy) - End-to-end mtp_generate_step completion --- tests/test_mtp.py | 182 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 182 insertions(+) create mode 100644 tests/test_mtp.py diff --git a/tests/test_mtp.py b/tests/test_mtp.py new file mode 100644 index 000000000..ba1a0e92f --- /dev/null +++ b/tests/test_mtp.py @@ -0,0 +1,182 @@ +import importlib +import unittest + +import mlx.core as mx +from mlx_lm.models.cache import make_prompt_cache +from mlx_lm.generate import generate_step, mtp_generate_step + + +def _make_qwen3_5_mtp_model(): + """Create a tiny Qwen3.5 model with an MTP head for testing.""" + module = importlib.import_module("mlx_lm.models.qwen3_5") + args = module.ModelArgs.from_dict( + { + "model_type": "qwen3_5", + "text_config": { + "model_type": "qwen3_5", + "hidden_size": 64, + "intermediate_size": 128, + "num_hidden_layers": 4, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "vocab_size": 256, + "linear_num_value_heads": 2, + "linear_num_key_heads": 2, + "linear_key_head_dim": 16, + "linear_value_head_dim": 16, + "linear_conv_kernel_dim": 3, + "full_attention_interval": 2, + "tie_word_embeddings": True, + "rms_norm_eps": 1e-5, + "head_dim": 32, + "rope_theta": 1000.0, + "partial_rotary_factor": 0.5, + "max_position_embeddings": 128, + "mtp_num_hidden_layers": 1, + }, + } + ) + model = module.Model(args) + model.set_dtype(mx.float32) + mx.eval(model.parameters()) + return model + + +class TestMTP(unittest.TestCase): + """Tests for native MTP (Multi-Token Prediction) speculative decoding. + + Uses a tiny synthetic Qwen3.5 model (4 layers, hidden=64, vocab=256) + with mtp_num_hidden_layers=1 and full_attention_interval=2, giving a + mix of GatedDeltaNet (SSM) and full-attention layers. + + Not tested here (would require a real tokenizer loaded from HF): + - stream_generate() with mtp=True/False flag dispatch + - Server integration (--mtp flag, is_batchable) + """ + + @classmethod + def setUpClass(cls): + cls.model = _make_qwen3_5_mtp_model() + + def test_mtp_module_exists(self): + """Model with mtp_num_hidden_layers=1 should have MTP head.""" + self.assertTrue(hasattr(self.model, "mtp_forward")) + self.assertTrue(hasattr(self.model, "make_mtp_cache")) + lm = self.model.language_model + self.assertTrue(hasattr(lm, "mtp")) + self.assertEqual(len(lm.mtp.layers), 1) + + def test_make_mtp_cache(self): + """make_mtp_cache should return one KVCache per MTP layer.""" + mtp_cache = self.model.make_mtp_cache() + self.assertEqual(len(mtp_cache), 1) + self.assertTrue(mtp_cache[0].is_trimmable()) + + def test_return_hidden(self): + """return_hidden=True should return (logits, hidden) with correct shapes.""" + inputs = mx.array([[0, 1, 2]]) + cache = make_prompt_cache(self.model) + out, hidden = self.model(inputs, cache=cache, return_hidden=True) + self.assertEqual(out.shape, (1, 3, 256)) + self.assertEqual(hidden.shape, (1, 3, 64)) + + def test_mtp_forward_shape(self): + """mtp_forward should produce logits of shape (B, 1, vocab).""" + hidden = mx.random.normal((1, 1, 64)) + next_ids = mx.array([[5]]) + mtp_cache = self.model.make_mtp_cache() + logits = self.model.mtp_forward(hidden, next_ids, mtp_cache) + self.assertEqual(logits.shape, (1, 1, 256)) + + def test_hidden_is_pre_norm(self): + """Hidden states returned with return_hidden should be pre-norm. + + This verifies the fix for double normalization: the backbone returns + pre-norm hidden states, and the final norm is applied only before + lm_head (not before the MTP head). + """ + lm = self.model.language_model + inputs = mx.array([[0, 1, 2]]) + cache = make_prompt_cache(self.model) + + _, hidden = lm(inputs, cache=cache, return_hidden=True) + + # Apply the final norm manually and check it changes the values. + normed = lm.model.norm(hidden) + self.assertFalse(mx.allclose(hidden, normed, atol=1e-5).item()) + + def test_quant_predicate_excludes_mtp_fc(self): + """quant_predicate should exclude mtp.fc from quantization.""" + lm = self.model.language_model + predicate = lm.quant_predicate + self.assertIsNotNone(predicate) + # mtp.fc should not be quantized + self.assertFalse(predicate("mtp.fc", None)) + # Regular layers should be quantized + self.assertTrue(predicate("layers.0.mlp.gate_proj", None)) + + def test_mtp_generate_identity(self): + """mtp_generate_step should produce the same greedy tokens as generate_step. + + This is the most important correctness test: it proves that the + draft/verify loop, SSM state rollback on rejection, and MTP cache + management are all correct. Any bug in these would cause the MTP + path to diverge from standard generation. + """ + prompt = mx.array([0, 1, 2, 3], dtype=mx.uint32) + n_tokens = 10 + + def greedy(logprobs): + return mx.argmax(logprobs, axis=-1) + + # Standard generation + std_cache = make_prompt_cache(self.model) + std_tokens = [] + for i, (tok, _) in enumerate( + generate_step(prompt, self.model, sampler=greedy, prompt_cache=std_cache) + ): + std_tokens.append(int(tok)) + if i + 1 >= n_tokens: + break + + # MTP generation + mtp_tokens = [] + for tok, _, _ in mtp_generate_step( + prompt, self.model, sampler=greedy, max_tokens=n_tokens + ): + mtp_tokens.append(int(tok)) + if len(mtp_tokens) >= n_tokens: + break + + self.assertEqual( + std_tokens, + mtp_tokens, + f"Token mismatch: std={std_tokens}, mtp={mtp_tokens}", + ) + + def test_mtp_generate_runs(self): + """mtp_generate_step should complete without errors. + + Exercises the full end-to-end path: prefill, backbone forward with + return_hidden, MTP draft generation, verification with n_confirmed, + SSM rollback on rejection, and MTP cache trimming. + """ + prompt = mx.array([0, 1, 2, 3], dtype=mx.uint32) + n_tokens = 10 + + def greedy(logprobs): + return mx.argmax(logprobs, axis=-1) + + tokens = [] + for tok, _, _ in mtp_generate_step( + prompt, self.model, sampler=greedy, max_tokens=n_tokens + ): + tokens.append(int(tok)) + if len(tokens) >= n_tokens: + break + + self.assertEqual(len(tokens), n_tokens) + + +if __name__ == "__main__": + unittest.main() From 78622ebd2f378a0697316b5fd528b305205b23b2 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Fri, 13 Mar 2026 00:48:59 +0100 Subject: [PATCH 07/29] fix(mtp): warn when --mtp flag is used with a model without MTP head Instead of silently falling back to standard generation, emit a warning so the user knows their --mtp flag had no effect. --- mlx_lm/generate.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 3932a6bdb..eacad7047 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -4,6 +4,7 @@ import contextlib import copy import functools +import warnings import json import sys import time @@ -902,6 +903,17 @@ def stream_generate( kwargs.pop("prompt_progress_callback", None) kwargs.pop("num_draft_tokens", None) token_generator = mtp_generate_step(prompt, model, **kwargs) + elif mtp: + warnings.warn( + "--mtp flag ignored: model does not have an MTP head. " + "Falling back to standard generation.", + stacklevel=2, + ) + kwargs.pop("num_draft_tokens", None) + token_generator = generate_step(prompt, model, **kwargs) + token_generator = ( + (token, logprobs, False) for token, logprobs in token_generator + ) else: kwargs.pop("num_draft_tokens", None) token_generator = generate_step(prompt, model, **kwargs) From 50faf3ed97b32c75a1a5fbbf1c5de7f5388c7c20 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Fri, 13 Mar 2026 01:05:55 +0100 Subject: [PATCH 08/29] style: apply black and isort formatting --- mlx_lm/generate.py | 19 +++++++++----- mlx_lm/models/qwen3_5.py | 57 +++++++++++++++++++++++++++++----------- tests/test_mtp.py | 3 ++- 3 files changed, 57 insertions(+), 22 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index eacad7047..91c778514 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -4,10 +4,10 @@ import contextlib import copy import functools -import warnings import json import sys import time +import warnings from collections import deque from dataclasses import dataclass from functools import partial @@ -721,7 +721,9 @@ def _process_and_sample(tokens, logits): def _step_backbone(y, n_predict=1, n_confirmed=0): """Run the backbone on ``y`` and return (tokens, logprobs, hidden).""" with mx.stream(generation_stream): - logits, hidden = model(y[None], cache=model_cache, return_hidden=True, n_confirmed=n_confirmed) + logits, hidden = model( + y[None], cache=model_cache, return_hidden=True, n_confirmed=n_confirmed + ) logits = logits[:, -n_predict:, :] quantize_cache_fn(model_cache) nonlocal prev_tokens @@ -789,11 +791,13 @@ def _prefill(y): y_with_draft = mx.concatenate( [y, mx.array([draft_tok.item()], mx.uint32)] ) - toks, lps, hidden = _step_backbone(y_with_draft, n_predict=2, n_confirmed=1) + toks, lps, hidden = _step_backbone( + y_with_draft, n_predict=2, n_confirmed=1 + ) mx.eval(toks, draft_tok) - verify_pred = toks[0] # backbone prediction after y → verify draft - bonus_tok = toks[1] # backbone prediction after draft_tok + verify_pred = toks[0] # backbone prediction after y → verify draft + bonus_tok = toks[1] # backbone prediction after draft_tok verify_lp = lps[0] bonus_lp = lps[1] @@ -823,7 +827,10 @@ def _prefill(y): # by GatedDeltaNet after the confirmed token. # Attention layers (KVCache): trim the draft-token entry. for c in model_cache: - if hasattr(c, "rollback_state") and c.rollback_state is not None: + if ( + hasattr(c, "rollback_state") + and c.rollback_state is not None + ): conv_snap, ssm_snap = c.rollback_state c[0] = conv_snap c[1] = ssm_snap diff --git a/mlx_lm/models/qwen3_5.py b/mlx_lm/models/qwen3_5.py index 327078a0f..59d0ba98c 100644 --- a/mlx_lm/models/qwen3_5.py +++ b/mlx_lm/models/qwen3_5.py @@ -166,8 +166,15 @@ def _process_chunk( k = inv_scale * mx.fast.rms_norm(k, None, 1e-6) out, new_ssm_state = gated_delta_update( - q, k, v, a_chunk, b_chunk, - self.A_log, self.dt_bias, ssm_state, ssm_mask, + q, + k, + v, + a_chunk, + b_chunk, + self.A_log, + self.dt_bias, + ssm_state, + ssm_mask, use_kernel=not self.training, ) return out, new_conv_state, new_ssm_state @@ -192,7 +199,9 @@ def __call__( conv_state = ( cache[0] if cache is not None and cache[0] is not None - else mx.zeros((B, self.conv_kernel_size - 1, self.conv_dim), dtype=inputs.dtype) + else mx.zeros( + (B, self.conv_kernel_size - 1, self.conv_dim), dtype=inputs.dtype + ) ) ssm_state = cache[1] if cache else None @@ -205,19 +214,29 @@ def __call__( mask_c = mask[:, :n_confirmed] if mask is not None else None mask_d = mask[:, n_confirmed:] if mask is not None else None out_c, conv_c, ssm_c = self._process_chunk( - qkv[:, :n_confirmed], a[:, :n_confirmed], b[:, :n_confirmed], - conv_state, ssm_state, mask_c, + qkv[:, :n_confirmed], + a[:, :n_confirmed], + b[:, :n_confirmed], + conv_state, + ssm_state, + mask_c, ) if cache is not None: cache.rollback_state = (conv_c, ssm_c) out_d, conv_f, ssm_f = self._process_chunk( - qkv[:, n_confirmed:], a[:, n_confirmed:], b[:, n_confirmed:], - conv_c, ssm_c, mask_d, + qkv[:, n_confirmed:], + a[:, n_confirmed:], + b[:, n_confirmed:], + conv_c, + ssm_c, + mask_d, ) out = mx.concatenate([out_c, out_d], axis=1) else: lengths = cache.lengths if cache is not None else None - out, conv_f, ssm_f = self._process_chunk(qkv, a, b, conv_state, ssm_state, mask, lengths=lengths) + out, conv_f, ssm_f = self._process_chunk( + qkv, a, b, conv_state, ssm_state, mask, lengths=lengths + ) if cache is not None: cache[0] = conv_f @@ -260,7 +279,9 @@ def __call__( n_confirmed: int = 0, ) -> mx.array: if self.is_linear: - r = self.linear_attn(self.input_layernorm(x), mask, cache, n_confirmed=n_confirmed) + r = self.linear_attn( + self.input_layernorm(x), mask, cache, n_confirmed=n_confirmed + ) else: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -275,7 +296,9 @@ def __init__(self, args: TextModelArgs): super().__init__() self.self_attn = Attention(args) self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) if args.num_experts > 0: self.mlp = SparseMoeBlock(args) else: @@ -304,9 +327,7 @@ def __init__(self, args: TextModelArgs): self.pre_fc_norm_hidden = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) self.pre_fc_norm_embedding = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) self.fc = nn.Linear(args.hidden_size * 2, args.hidden_size, bias=False) - self.layers = [ - MTPDecoderLayer(args) for _ in range(args.mtp_num_hidden_layers) - ] + self.layers = [MTPDecoderLayer(args) for _ in range(args.mtp_num_hidden_layers)] self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) def __call__( @@ -364,7 +385,11 @@ def __call__( for layer, c in zip(self.layers, cache): mask = ssm_mask if layer.is_linear else fa_mask - kw = {"n_confirmed": n_confirmed} if layer.is_linear and n_confirmed > 0 else {} + kw = ( + {"n_confirmed": n_confirmed} + if layer.is_linear and n_confirmed > 0 + else {} + ) hidden_states = layer(hidden_states, mask=mask, cache=c, **kw) return hidden_states @@ -389,7 +414,9 @@ def __call__( return_hidden: bool = False, n_confirmed: int = 0, ) -> mx.array: - hidden = self.model(inputs, cache, input_embeddings=input_embeddings, n_confirmed=n_confirmed) + hidden = self.model( + inputs, cache, input_embeddings=input_embeddings, n_confirmed=n_confirmed + ) normed = self.model.norm(hidden) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(normed) diff --git a/tests/test_mtp.py b/tests/test_mtp.py index ba1a0e92f..1db571bde 100644 --- a/tests/test_mtp.py +++ b/tests/test_mtp.py @@ -2,8 +2,9 @@ import unittest import mlx.core as mx -from mlx_lm.models.cache import make_prompt_cache + from mlx_lm.generate import generate_step, mtp_generate_step +from mlx_lm.models.cache import make_prompt_cache def _make_qwen3_5_mtp_model(): From be3bbf3ad827ab51bcf0debf035bf473d5c7f962 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Tue, 17 Mar 2026 12:07:02 +0100 Subject: [PATCH 09/29] fix(mtp): stack per-expert MTP weights for MoE models in sanitize() MTP layers in MoE models (35B-A3B, 122B-A10B) ship unfused per-expert weights (mtp.layers.{l}.mlp.experts.{i}.gate_proj.weight) whereas the backbone uses pre-fused switch_mlp format. Conversion was failing with ~768 parameters not in model. Add a stacking loop in qwen3_5_moe.py sanitize() after the backbone expert loop, mirroring the same pattern for MTP prefixes. Co-authored-by: Thump604 --- mlx_lm/models/qwen3_5_moe.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/mlx_lm/models/qwen3_5_moe.py b/mlx_lm/models/qwen3_5_moe.py index 53ab8530e..23a6216de 100644 --- a/mlx_lm/models/qwen3_5_moe.py +++ b/mlx_lm/models/qwen3_5_moe.py @@ -2,6 +2,8 @@ from dataclasses import dataclass +import mlx.core as mx + from .base import BaseModelArgs from .qwen3_5 import Model as Qwen3_5Model @@ -49,4 +51,20 @@ def sanitize(self, weights): f"{prefix}.experts.down_proj" ) + # Stack per-expert MTP weights into switch_mlp format. + # MTP layers use unfused per-expert weights (experts.{i}.gate_proj etc) + # unlike backbone layers which use fused gate_up_proj. + mtp_num = getattr(self.language_model.args, "mtp_num_hidden_layers", 0) + num_experts = self.language_model.args.num_experts + for l in range(mtp_num): + prefix = f"language_model.mtp.layers.{l}.mlp" + test_key = f"{prefix}.experts.0.gate_proj.weight" + if test_key in new_weights: + for n in ["gate_proj", "up_proj", "down_proj"]: + to_join = [ + new_weights.pop(f"{prefix}.experts.{e}.{n}.weight") + for e in range(num_experts) + ] + new_weights[f"{prefix}.switch_mlp.{n}.weight"] = mx.stack(to_join) + return self.language_model.sanitize(new_weights) From ce0bcb7f2c949e68717678c71595c5743e5049c8 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Sun, 22 Mar 2026 12:07:04 +0100 Subject: [PATCH 10/29] fix(mtp): raise clear error when config has MTP but weights do not When mtp_num_hidden_layers > 0 but the model weights contain no MTP parameters, the previous error was a cryptic 'Missing N parameters'. Now raises a ValueError with an actionable message. --- mlx_lm/models/qwen3_5.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mlx_lm/models/qwen3_5.py b/mlx_lm/models/qwen3_5.py index 59d0ba98c..fb8694b6b 100644 --- a/mlx_lm/models/qwen3_5.py +++ b/mlx_lm/models/qwen3_5.py @@ -476,6 +476,11 @@ def sanitize(self, weights): # Keep MTP weights if this model has an MTP head; drop them otherwise if not hasattr(self, "mtp"): weights = {k: v for k, v in weights.items() if "mtp." not in k} + elif not any("mtp." in k for k in weights): + raise ValueError( + "Config specifies mtp_num_hidden_layers > 0 but the model weights " + "contain no MTP parameters. Set mtp_num_hidden_layers=0 to disable MTP." + ) if self.args.tie_word_embeddings: weights.pop("lm_head.weight", None) From 4ffc627034c773e38a7271b298748572503317f2 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Fri, 3 Apr 2026 03:13:11 +0200 Subject: [PATCH 11/29] feat(mtp): add probabilistic draft acceptance for stochastic samplers With sampler=None (greedy decoding): keep exact-match acceptance, this is the mathematically correct criterion for a deterministic point-mass distribution. For stochastic samplers (temp > 0), accept the draft token with probability min(1, p_target / p_draft), computed from the log-probability distributions already returned by _process_and_sample. No extra forward passes needed. This recovers the greedy acceptance rate (~46%) at any temperature, vs ~43% with exact-match at temp=0.6 on Qwen3.5-27B 4-bit. Suggested by @janhilgard; implementation reference in #1085 by @Thump604. --- mlx_lm/generate.py | 24 +++++++++++++++++++----- tests/test_mtp.py | 27 ++++++++++----------------- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 91c778514..32357ae38 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -5,6 +5,8 @@ import copy import functools import json +import math +import random import sys import time import warnings @@ -701,6 +703,9 @@ def mtp_generate_step( model_cache = prompt_cache[:n_main] mtp_cache = prompt_cache[n_main:] or model.make_mtp_cache() + # Exact-match acceptance for greedy (sampler=None); probabilistic + # acceptance min(1, p_target/p_draft) for stochastic samplers. + _is_greedy = sampler is None sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) quantize_cache_fn = functools.partial( @@ -770,7 +775,7 @@ def _prefill(y): try: while True: if draft_tok is None: - # No pending draft — run backbone only, then generate first draft + # No pending draft: run backbone only, then generate first draft. toks, lps, hidden = _step_backbone(y, n_predict=1) mx.eval(toks) main_tok = toks[0] @@ -796,13 +801,22 @@ def _prefill(y): ) mx.eval(toks, draft_tok) - verify_pred = toks[0] # backbone prediction after y → verify draft + verify_pred = toks[0] # backbone prediction for y, used to verify draft bonus_tok = toks[1] # backbone prediction after draft_tok verify_lp = lps[0] bonus_lp = lps[1] - if verify_pred.item() == draft_tok.item(): - # Draft accepted — discard rollback snapshots. + draft_tok_id = draft_tok.item() + if _is_greedy: + accept = verify_pred.item() == draft_tok_id + else: + # Probabilistic acceptance: min(1, p_target / p_draft). + log_accept = ( + verify_lp[draft_tok_id] - draft_lp[draft_tok_id] + ).item() + accept = log_accept >= 0 or random.random() < math.exp(log_accept) + if accept: + # Draft accepted: discard rollback snapshots. for c in model_cache: if hasattr(c, "rollback_state"): c.rollback_state = None @@ -822,7 +836,7 @@ def _prefill(y): mx.eval(draft_tok) y = mx.array([bonus_tok.item()], mx.uint32) else: - # Draft rejected — roll back all caches to the state after y. + # Draft rejected: roll back all caches to the state after y. # SSM layers (ArraysCache): restore the conv/ssm snapshot saved # by GatedDeltaNet after the confirmed token. # Attention layers (KVCache): trim the draft-token entry. diff --git a/tests/test_mtp.py b/tests/test_mtp.py index 1db571bde..fb07c94b2 100644 --- a/tests/test_mtp.py +++ b/tests/test_mtp.py @@ -127,24 +127,19 @@ def test_mtp_generate_identity(self): prompt = mx.array([0, 1, 2, 3], dtype=mx.uint32) n_tokens = 10 - def greedy(logprobs): - return mx.argmax(logprobs, axis=-1) - - # Standard generation + # Standard generation, greedy (default sampler is argmax). std_cache = make_prompt_cache(self.model) std_tokens = [] for i, (tok, _) in enumerate( - generate_step(prompt, self.model, sampler=greedy, prompt_cache=std_cache) + generate_step(prompt, self.model, prompt_cache=std_cache) ): std_tokens.append(int(tok)) if i + 1 >= n_tokens: break - # MTP generation + # MTP generation, greedy (sampler=None uses exact-match acceptance). mtp_tokens = [] - for tok, _, _ in mtp_generate_step( - prompt, self.model, sampler=greedy, max_tokens=n_tokens - ): + for tok, _, _ in mtp_generate_step(prompt, self.model, max_tokens=n_tokens): mtp_tokens.append(int(tok)) if len(mtp_tokens) >= n_tokens: break @@ -155,22 +150,20 @@ def greedy(logprobs): f"Token mismatch: std={std_tokens}, mtp={mtp_tokens}", ) - def test_mtp_generate_runs(self): - """mtp_generate_step should complete without errors. + def test_mtp_probabilistic_acceptance_completes(self): + """mtp_generate_step should complete without errors with a stochastic sampler. - Exercises the full end-to-end path: prefill, backbone forward with - return_hidden, MTP draft generation, verification with n_confirmed, - SSM rollback on rejection, and MTP cache trimming. + Exercises the probabilistic acceptance path: min(1, p_target / p_draft). """ prompt = mx.array([0, 1, 2, 3], dtype=mx.uint32) n_tokens = 10 - def greedy(logprobs): - return mx.argmax(logprobs, axis=-1) + def stochastic(logprobs): + return mx.random.categorical(logprobs) tokens = [] for tok, _, _ in mtp_generate_step( - prompt, self.model, sampler=greedy, max_tokens=n_tokens + prompt, self.model, sampler=stochastic, max_tokens=n_tokens ): tokens.append(int(tok)) if len(tokens) >= n_tokens: From 9c734c213aa435f7eef5f339bc71b3a77737828c Mon Sep 17 00:00:00 2001 From: AirRunner Date: Fri, 3 Apr 2026 05:48:01 +0200 Subject: [PATCH 12/29] refactor(mtp): clean up mtp_generate_step and add dynamic MTP/batch switching --- mlx_lm/generate.py | 186 ++++++++++++++++------------------- mlx_lm/models/qwen3_5.py | 33 +++---- mlx_lm/models/qwen3_5_moe.py | 4 +- mlx_lm/server.py | 15 +-- 4 files changed, 107 insertions(+), 131 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 32357ae38..848122f66 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -720,8 +720,29 @@ def _process_and_sample(tokens, logits): for processor in logits_processors: logits = processor(tokens, logits) logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) - tok = sampler(logprobs) - return tok, logprobs + return sampler(logprobs), logprobs + + def _clear_rollback(): + for c in model_cache: + if hasattr(c, "rollback_state"): + c.rollback_state = None + + def _rollback_draft(): + """Restore caches to the state after the confirmed token. + + SSM layers (ArraysCache): restore the conv/ssm snapshot saved by + GatedDeltaNet after the confirmed token. + Attention layers (KVCache): trim the draft-token entry. + """ + for c in model_cache: + if hasattr(c, "rollback_state") and c.rollback_state is not None: + conv_snap, ssm_snap = c.rollback_state + c[0] = conv_snap + c[1] = ssm_snap + c.rollback_state = None + elif c.is_trimmable(): + c.trim(1) + cache.trim_prompt_cache(mtp_cache, 1) def _step_backbone(y, n_predict=1, n_confirmed=0): """Run the backbone on ``y`` and return (tokens, logprobs, hidden).""" @@ -769,101 +790,65 @@ def _prefill(y): y = _prefill(y) ntoks = 0 - draft_tok = None - draft_lp = None - - try: - while True: - if draft_tok is None: - # No pending draft: run backbone only, then generate first draft. - toks, lps, hidden = _step_backbone(y, n_predict=1) - mx.eval(toks) - main_tok = toks[0] - main_lp = lps[0] + draft_tok = draft_lp = None + + while ntoks < max_tokens: + if draft_tok is None: + # No pending draft: run backbone only, then generate first draft. + toks, lps, hidden = _step_backbone(y, n_predict=1) + mx.eval(toks) + main_tok, main_lp = toks[0], lps[0] + ntoks += 1 + yield main_tok.item(), main_lp, False + if ntoks >= max_tokens: + return + draft_tok, draft_lp = _step_mtp(hidden[:, -1:, :], main_tok) + mx.eval(draft_tok) + y = mx.array([main_tok.item()], mx.uint32) + else: + # Verify draft: run backbone over [y, draft_tok]. + # n_confirmed=1 causes GatedDeltaNet to snapshot its SSM/conv state + # after the confirmed token y, enabling exact rollback on rejection. + y_with_draft = mx.concatenate([y, mx.array([draft_tok.item()], mx.uint32)]) + toks, lps, hidden = _step_backbone(y_with_draft, n_predict=2, n_confirmed=1) + mx.eval(toks, draft_tok) + + verify_pred, bonus_tok = toks[0], toks[1] + verify_lp, bonus_lp = lps[0], lps[1] + draft_tok_id = draft_tok.item() + + if _is_greedy: + accept = verify_pred.item() == draft_tok_id + else: + # Probabilistic acceptance: min(1, p_target / p_draft). + log_accept = (verify_lp[draft_tok_id] - draft_lp[draft_tok_id]).item() + accept = log_accept >= 0 or random.random() < math.exp(log_accept) + if accept: + _clear_rollback() ntoks += 1 - yield main_tok.item(), main_lp, False + yield draft_tok_id, draft_lp, True if ntoks >= max_tokens: - break - - draft_tok, draft_lp = _step_mtp(hidden[:, -1:, :], main_tok) + return + ntoks += 1 + yield bonus_tok.item(), bonus_lp, False + if ntoks >= max_tokens: + return + # Next draft from MTP at draft_tok's hidden state. + draft_tok, draft_lp = _step_mtp(hidden[:, 1:2, :], bonus_tok) mx.eval(draft_tok) - y = mx.array([main_tok.item()], mx.uint32) + y = mx.array([bonus_tok.item()], mx.uint32) else: - # Verify draft: run backbone over [y, draft_tok]. - # n_confirmed=1 causes GatedDeltaNet to snapshot its SSM/conv state - # after the confirmed token y, enabling exact rollback on rejection. - y_with_draft = mx.concatenate( - [y, mx.array([draft_tok.item()], mx.uint32)] - ) - toks, lps, hidden = _step_backbone( - y_with_draft, n_predict=2, n_confirmed=1 - ) - mx.eval(toks, draft_tok) - - verify_pred = toks[0] # backbone prediction for y, used to verify draft - bonus_tok = toks[1] # backbone prediction after draft_tok - verify_lp = lps[0] - bonus_lp = lps[1] - - draft_tok_id = draft_tok.item() - if _is_greedy: - accept = verify_pred.item() == draft_tok_id - else: - # Probabilistic acceptance: min(1, p_target / p_draft). - log_accept = ( - verify_lp[draft_tok_id] - draft_lp[draft_tok_id] - ).item() - accept = log_accept >= 0 or random.random() < math.exp(log_accept) - if accept: - # Draft accepted: discard rollback snapshots. - for c in model_cache: - if hasattr(c, "rollback_state"): - c.rollback_state = None - - ntoks += 1 - yield draft_tok.item(), draft_lp, True - if ntoks >= max_tokens: - break - - ntoks += 1 - yield bonus_tok.item(), bonus_lp, False - if ntoks >= max_tokens: - break - - # Next draft from MTP at draft_tok's hidden state - draft_tok, draft_lp = _step_mtp(hidden[:, 1:2, :], bonus_tok) - mx.eval(draft_tok) - y = mx.array([bonus_tok.item()], mx.uint32) - else: - # Draft rejected: roll back all caches to the state after y. - # SSM layers (ArraysCache): restore the conv/ssm snapshot saved - # by GatedDeltaNet after the confirmed token. - # Attention layers (KVCache): trim the draft-token entry. - for c in model_cache: - if ( - hasattr(c, "rollback_state") - and c.rollback_state is not None - ): - conv_snap, ssm_snap = c.rollback_state - c[0] = conv_snap - c[1] = ssm_snap - c.rollback_state = None - elif c.is_trimmable(): - c.trim(1) - cache.trim_prompt_cache(mtp_cache, 1) - - ntoks += 1 - yield verify_pred.item(), verify_lp, False - if ntoks >= max_tokens: - break - - # Next draft from MTP at y's hidden state - draft_tok, draft_lp = _step_mtp(hidden[:, 0:1, :], verify_pred) - mx.eval(draft_tok) - y = mx.array([verify_pred.item()], mx.uint32) - finally: - pass + _rollback_draft() + verify_tok_id = verify_pred.item() + ntoks += 1 + yield verify_tok_id, verify_lp, False + if ntoks >= max_tokens: + return + # Next draft from MTP at y's hidden state. + draft_tok, draft_lp = _step_mtp(hidden[:, 0:1, :], verify_pred) + mx.eval(draft_tok) + y = mx.array([verify_tok_id], mx.uint32) def stream_generate( @@ -924,18 +909,13 @@ def stream_generate( kwargs.pop("prompt_progress_callback", None) kwargs.pop("num_draft_tokens", None) token_generator = mtp_generate_step(prompt, model, **kwargs) - elif mtp: - warnings.warn( - "--mtp flag ignored: model does not have an MTP head. " - "Falling back to standard generation.", - stacklevel=2, - ) - kwargs.pop("num_draft_tokens", None) - token_generator = generate_step(prompt, model, **kwargs) - token_generator = ( - (token, logprobs, False) for token, logprobs in token_generator - ) else: + if mtp: + warnings.warn( + "--mtp flag ignored: model does not have an MTP head. " + "Falling back to standard generation.", + stacklevel=2, + ) kwargs.pop("num_draft_tokens", None) token_generator = generate_step(prompt, model, **kwargs) # from_draft always false for non-speculative generation diff --git a/mlx_lm/models/qwen3_5.py b/mlx_lm/models/qwen3_5.py index fb8694b6b..9a09c0738 100644 --- a/mlx_lm/models/qwen3_5.py +++ b/mlx_lm/models/qwen3_5.py @@ -196,13 +196,13 @@ def __call__( b = self.in_proj_b(inputs) a = self.in_proj_a(inputs) - conv_state = ( - cache[0] - if cache is not None and cache[0] is not None - else mx.zeros( - (B, self.conv_kernel_size - 1, self.conv_dim), dtype=inputs.dtype + if cache is not None and cache[0] is not None: + conv_state = cache[0] + else: + conv_state = mx.zeros( + (B, self.conv_kernel_size - 1, self.conv_dim), + dtype=inputs.dtype, ) - ) ssm_state = cache[1] if cache else None if mask is not None: @@ -337,8 +337,6 @@ def __call__( embed_tokens: nn.Embedding, cache: Optional[Any] = None, ) -> mx.array: - # hidden_states : (B, 1, H) — backbone hidden at last accepted position - # next_token_ids: (B, 1) — t_main (main model's prediction for t+1) embeds = embed_tokens(next_token_ids) # (B, 1, H) e = self.pre_fc_norm_embedding(embeds) h = self.pre_fc_norm_hidden(hidden_states) @@ -385,12 +383,9 @@ def __call__( for layer, c in zip(self.layers, cache): mask = ssm_mask if layer.is_linear else fa_mask - kw = ( - {"n_confirmed": n_confirmed} - if layer.is_linear and n_confirmed > 0 - else {} + hidden_states = layer( + hidden_states, mask=mask, cache=c, n_confirmed=n_confirmed ) - hidden_states = layer(hidden_states, mask=mask, cache=c, **kw) return hidden_states @@ -435,12 +430,12 @@ def mtp_forward( """Run the MTP head and apply the shared lm_head. Args: - hidden_states: (B, 1, H) — backbone hidden state at the last position. - next_token_ids: (B, 1) — sampled main token (t_main). - mtp_cache: list of KVCache entries for the MTP transformer layer(s). + hidden_states: Backbone pre-norm hidden state, shape (B, 1, H). + next_token_ids: Sampled main token ids, shape (B, 1). + mtp_cache: KVCache entries for the MTP transformer layers. Returns: - logits: (B, 1, vocab_size) + logits of shape (B, 1, vocab_size). """ mtp_out = self.mtp( hidden_states, @@ -470,8 +465,8 @@ def sanitize(self, weights): "conv1d.weight" in k and v.shape[-1] != 1 for k, v in weights.items() ) # Norm weights need a +1 shift only in raw HF checkpoints (detected via - # unsanitized conv1d). Already-converted MLX models (conv1d fixed) must NOT - # be shifted again — even when they contain MTP weights. + # unsanitized conv1d). Already-converted MLX models must not be shifted + # again, even when they contain MTP weights. should_shift_norm_weights = has_unsanitized_conv1d # Keep MTP weights if this model has an MTP head; drop them otherwise if not hasattr(self, "mtp"): diff --git a/mlx_lm/models/qwen3_5_moe.py b/mlx_lm/models/qwen3_5_moe.py index 23a6216de..f1f6ec42d 100644 --- a/mlx_lm/models/qwen3_5_moe.py +++ b/mlx_lm/models/qwen3_5_moe.py @@ -56,8 +56,8 @@ def sanitize(self, weights): # unlike backbone layers which use fused gate_up_proj. mtp_num = getattr(self.language_model.args, "mtp_num_hidden_layers", 0) num_experts = self.language_model.args.num_experts - for l in range(mtp_num): - prefix = f"language_model.mtp.layers.{l}.mlp" + for layer_idx in range(mtp_num): + prefix = f"language_model.mtp.layers.{layer_idx}.mlp" test_key = f"{prefix}.experts.0.gate_proj.weight" if test_key in new_weights: for n in ["gate_proj", "up_proj", "down_proj"]: diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 8ebdab76f..d843b0261 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -378,12 +378,6 @@ def _load(self, model_path, adapter_path=None, draft_model_path=None): self.model = model self.tokenizer = tokenizer self.draft_model = draft_model - # MTP speculative decoding uses single-sequence generation - # (draft/verify loop is incompatible with batch generation). - # TODO: dynamically switch between MTP (1 request) and - # BatchGenerator (>= 2 concurrent requests). - if self.cli_args.mtp and hasattr(model, "mtp_forward"): - is_batchable = False self.is_batchable = is_batchable def load_default(self): @@ -816,7 +810,14 @@ def get_next_request(timeout=None): rqueue.put(e) continue - if not self._is_batchable(args): + # Prefer single-sequence MTP when the queue is empty; + # fall back to BatchGenerator when requests are queued. + mtp_active = getattr(self.cli_args, "mtp", False) and hasattr( + model, "mtp_forward" + ) + if not self._is_batchable(args) or ( + mtp_active and self.requests.empty() + ): self._serve_single((rqueue, request, args)) continue From 77f616d3e3dd6aa83d122bf1893bf4a8015aec2d Mon Sep 17 00:00:00 2001 From: AirRunner Date: Sat, 4 Apr 2026 03:32:24 +0200 Subject: [PATCH 13/29] fix(mtp): always leave 1 token for _step_backbone in _prefill The _prefill loop in mtp_generate_step previously stopped when y.size <= prefill_step_size (512), leaving up to 512 tokens for _step_backbone(..., return_hidden=True). Since return_hidden=True keeps the full hidden state [1, N, d_model] live, N > 1 caused unnecessary memory pressure on longer prompts. The loop now stops at exactly 1 token (matching generate_step's design), ensuring the hidden state is always [1, 1, d_model]. Default prefill_step_size raised from 512 to 2048 accordingly. --- mlx_lm/generate.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 848122f66..c97c2e106 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -671,7 +671,7 @@ def mtp_generate_step( sampler: Optional[Callable[[mx.array], mx.array]] = None, logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, prompt_cache: Optional[Any] = None, - prefill_step_size: int = 512, + prefill_step_size: int = 2048, kv_bits: Optional[int] = None, kv_group_size: int = 64, quantized_kv_start: int = 0, @@ -778,11 +778,14 @@ def _step_mtp(hidden_last, main_tok): return draft_tok, draft_lp def _prefill(y): - while y.size > prefill_step_size: - model(y[:prefill_step_size][None], cache=model_cache) + # Leave exactly 1 token for _step_backbone: return_hidden=True keeps + # the hidden state [1, N, d_model] live, so N must be 1. + while y.size > 1: + n = min(prefill_step_size, y.size - 1) + model(y[:n][None], cache=model_cache) quantize_cache_fn(model_cache) mx.eval([c.state for c in model_cache if hasattr(c, "state")]) - y = y[prefill_step_size:] + y = y[n:] mx.clear_cache() return y From 67c10e6314029b92d735f99f73a0ff5e1f88a926 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Sat, 25 Apr 2026 18:35:41 +0200 Subject: [PATCH 14/29] fix(mtp): correct prev_tokens management for logits processors Three bugs caused logits processors to receive stale token context in mtp_generate_step, breaking any processor that reads tokens[-1]: 1. _step_backbone used a fixed y_ctx slice on every loop iteration for n_predict=2, adding y[0] twice instead of y[0] then y[1]. The bonus token was therefore sampled with the wrong context. 2. _step_mtp passed prev_tokens directly to _process_and_sample without including main_tok, so tokens[-1] was the input token of the preceding backbone step rather than the just-sampled token. Fixed with a local tokens_for_proc that appends main_tok without mutating prev_tokens (mutating would double-count main_tok when _step_backbone adds y[0] at the next verify pass). 3. On draft rejection, prev_tokens retained the rejected draft token added by _step_backbone at i=1, corrupting context for the subsequent _step_mtp call. All three changes are gated on 'if logits_processors:' and have no effect on the no-processor path (benchmarks unchanged). Add two regression tests: - test_mtp_generate_identity_with_logits_processor: verifies that mtp_generate_step and generate_step produce identical greedy output under a context-sensitive stateless processor (covers bugs 1 and 3). - test_mtp_processor_prev_tokens_correct_at_draft_step: a forcing processor deterministically sets T0=4 and verifies the MTP head receives T0 as tokens[-1], not the preceding prompt token (bug 2). --- mlx_lm/generate.py | 17 ++++++--- tests/test_mtp.py | 86 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 4 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index c97c2e106..4d96d10de 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -754,13 +754,12 @@ def _step_backbone(y, n_predict=1, n_confirmed=0): quantize_cache_fn(model_cache) nonlocal prev_tokens toks, lps = [], [] - y_ctx = y if n_predict == 1 else y[: -(n_predict - 1)] for i in range(n_predict): if logits_processors: prev_tokens = ( - mx.concatenate([prev_tokens, y_ctx]) + mx.concatenate([prev_tokens, y[i : i + 1]]) if prev_tokens is not None - else y_ctx + else y[i : i + 1] ) tok, lp = _process_and_sample(prev_tokens, logits[:, i, :].squeeze(0)) toks.append(tok) @@ -774,7 +773,15 @@ def _step_mtp(hidden_last, main_tok): mtp_logits = model.mtp_forward(hidden_last, next_ids, mtp_cache) quantize_cache_fn(mtp_cache) mtp_logits = mtp_logits[:, -1, :].squeeze(0) - draft_tok, draft_lp = _process_and_sample(prev_tokens, mtp_logits) + if logits_processors: + tokens_for_proc = ( + mx.concatenate([prev_tokens, main_tok.reshape(-1)]) + if prev_tokens is not None + else main_tok.reshape(-1) + ) + else: + tokens_for_proc = prev_tokens + draft_tok, draft_lp = _process_and_sample(tokens_for_proc, mtp_logits) return draft_tok, draft_lp def _prefill(y): @@ -843,6 +850,8 @@ def _prefill(y): y = mx.array([bonus_tok.item()], mx.uint32) else: _rollback_draft() + if logits_processors and prev_tokens is not None: + prev_tokens = prev_tokens[:-1] # discard rejected draft token verify_tok_id = verify_pred.item() ntoks += 1 yield verify_tok_id, verify_lp, False diff --git a/tests/test_mtp.py b/tests/test_mtp.py index fb07c94b2..64068a203 100644 --- a/tests/test_mtp.py +++ b/tests/test_mtp.py @@ -171,6 +171,92 @@ def stochastic(logprobs): self.assertEqual(len(tokens), n_tokens) + def test_mtp_generate_identity_with_logits_processor(self): + """mtp_generate_step must produce the same greedy tokens as generate_step + when a context-sensitive stateless processor is applied. + + A processor that boosts (tokens[-1] + 1) % vocab biases sampling based on + the last token. Incorrect prev_tokens management in the verify pass would + cause the bonus token or the token after a rejection to be sampled with + the wrong bias, producing a sequence that diverges from serial generation. + """ + prompt = mx.array([0, 1, 2, 3], dtype=mx.uint32) + n_tokens = 10 + + def context_processor(tokens, logits): + if tokens is None or tokens.size == 0: + return logits + target = (int(tokens[-1].item()) + 1) % logits.shape[-1] + # 1D boost broadcasts correctly for both (vocab,) and (1, vocab) logits. + boost = mx.zeros(logits.shape[-1]) + return logits + boost.at[target].add(10.0) + + std_cache = make_prompt_cache(self.model) + std_tokens = [] + for i, (tok, _) in enumerate( + generate_step( + prompt, + self.model, + prompt_cache=std_cache, + logits_processors=[context_processor], + ) + ): + std_tokens.append(int(tok)) + if i + 1 >= n_tokens: + break + + mtp_tokens = [] + for tok, _, _ in mtp_generate_step( + prompt, + self.model, + max_tokens=n_tokens, + logits_processors=[context_processor], + ): + mtp_tokens.append(int(tok)) + if len(mtp_tokens) >= n_tokens: + break + + self.assertEqual(std_tokens, mtp_tokens) + + def test_mtp_processor_prev_tokens_correct_at_draft_step(self): + """The processor must see the just-sampled backbone token as tokens[-1] + when the MTP head runs, not the preceding input token. + + A forcing processor logs tokens[-1] on every call. When tokens[-1] equals + the last prompt token (3) it applies a large boost to token 4, guaranteeing + the backbone samples token 4 regardless of model weights. The second + processor call comes from the MTP head: if the token context is correct it + sees 4; if stale it sees 3 again. + """ + # Last prompt token is 3; the forcing processor boosts token 4 when it + # sees 3, so the backbone deterministically samples T0 = 4 regardless of weights. + prompt = mx.array([0, 1, 2, 3], dtype=mx.uint32) + + logged: list[int] = [] + + def forcing_processor(tokens, logits): + if tokens is not None and tokens.size > 0: + last = int(tokens[-1].item()) + logged.append(last) + if last == 3: + boost = mx.zeros(logits.shape[-1]) + return logits + boost.at[4].add(1000.0) + return logits + + for _tok, _, _ in mtp_generate_step( + prompt, + self.model, + max_tokens=2, + logits_processors=[forcing_processor], + ): + pass + + # First call (backbone): context is the last prompt token. + self.assertGreaterEqual(len(logged), 2) + self.assertEqual(logged[0], 3) + # Second call (MTP head): context must be T0 = 4, not the prompt token. + self.assertEqual(logged[1], 4) + if __name__ == "__main__": unittest.main() From b7f8aa4412b29a45dc52a6f9724c1c94c203dad8 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Sat, 25 Apr 2026 20:01:58 +0200 Subject: [PATCH 15/29] refactor(mtp): thread prev_tokens explicitly through mtp_generate_step Replace nonlocal mutation with explicit parameter/return value threading. _step_backbone now takes prev_tokens as an argument and returns it as a fourth value; _step_mtp takes prev_tokens as a third argument. The main loop unpacks and passes prev_tokens explicitly at every call site. Also name the three hidden-state slice indices (hidden_at_main, hidden_at_confirmed, hidden_at_draft) so their roles are self-evident. --- mlx_lm/generate.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 4d96d10de..adf175262 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -744,15 +744,14 @@ def _rollback_draft(): c.trim(1) cache.trim_prompt_cache(mtp_cache, 1) - def _step_backbone(y, n_predict=1, n_confirmed=0): - """Run the backbone on ``y`` and return (tokens, logprobs, hidden).""" + def _step_backbone(y, prev_tokens, n_predict=1, n_confirmed=0): + """Run the backbone on ``y`` and return (tokens, logprobs, hidden, prev_tokens).""" with mx.stream(generation_stream): logits, hidden = model( y[None], cache=model_cache, return_hidden=True, n_confirmed=n_confirmed ) logits = logits[:, -n_predict:, :] quantize_cache_fn(model_cache) - nonlocal prev_tokens toks, lps = [], [] for i in range(n_predict): if logits_processors: @@ -764,9 +763,9 @@ def _step_backbone(y, n_predict=1, n_confirmed=0): tok, lp = _process_and_sample(prev_tokens, logits[:, i, :].squeeze(0)) toks.append(tok) lps.append(lp) - return mx.stack(toks), mx.stack(lps), hidden + return mx.stack(toks), mx.stack(lps), hidden, prev_tokens - def _step_mtp(hidden_last, main_tok): + def _step_mtp(hidden_last, main_tok, prev_tokens): """Run the MTP head and return (draft_token, draft_logprobs).""" next_ids = main_tok.reshape(1, 1) with mx.stream(generation_stream): @@ -805,14 +804,15 @@ def _prefill(y): while ntoks < max_tokens: if draft_tok is None: # No pending draft: run backbone only, then generate first draft. - toks, lps, hidden = _step_backbone(y, n_predict=1) + toks, lps, hidden, prev_tokens = _step_backbone(y, prev_tokens, n_predict=1) mx.eval(toks) main_tok, main_lp = toks[0], lps[0] ntoks += 1 yield main_tok.item(), main_lp, False if ntoks >= max_tokens: return - draft_tok, draft_lp = _step_mtp(hidden[:, -1:, :], main_tok) + hidden_at_main = hidden[:, -1:, :] + draft_tok, draft_lp = _step_mtp(hidden_at_main, main_tok, prev_tokens) mx.eval(draft_tok) y = mx.array([main_tok.item()], mx.uint32) else: @@ -820,7 +820,9 @@ def _prefill(y): # n_confirmed=1 causes GatedDeltaNet to snapshot its SSM/conv state # after the confirmed token y, enabling exact rollback on rejection. y_with_draft = mx.concatenate([y, mx.array([draft_tok.item()], mx.uint32)]) - toks, lps, hidden = _step_backbone(y_with_draft, n_predict=2, n_confirmed=1) + toks, lps, hidden, prev_tokens = _step_backbone( + y_with_draft, prev_tokens, n_predict=2, n_confirmed=1 + ) mx.eval(toks, draft_tok) verify_pred, bonus_tok = toks[0], toks[1] @@ -834,6 +836,9 @@ def _prefill(y): log_accept = (verify_lp[draft_tok_id] - draft_lp[draft_tok_id]).item() accept = log_accept >= 0 or random.random() < math.exp(log_accept) + hidden_at_confirmed = hidden[:, 0:1, :] + hidden_at_draft = hidden[:, 1:2, :] + if accept: _clear_rollback() ntoks += 1 @@ -845,7 +850,7 @@ def _prefill(y): if ntoks >= max_tokens: return # Next draft from MTP at draft_tok's hidden state. - draft_tok, draft_lp = _step_mtp(hidden[:, 1:2, :], bonus_tok) + draft_tok, draft_lp = _step_mtp(hidden_at_draft, bonus_tok, prev_tokens) mx.eval(draft_tok) y = mx.array([bonus_tok.item()], mx.uint32) else: @@ -858,7 +863,9 @@ def _prefill(y): if ntoks >= max_tokens: return # Next draft from MTP at y's hidden state. - draft_tok, draft_lp = _step_mtp(hidden[:, 0:1, :], verify_pred) + draft_tok, draft_lp = _step_mtp( + hidden_at_confirmed, verify_pred, prev_tokens + ) mx.eval(draft_tok) y = mx.array([verify_tok_id], mx.uint32) From 48e1fca559a2c5fb336289dbb2c79f5fab1cd916 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Sun, 26 Apr 2026 03:16:26 +0200 Subject: [PATCH 16/29] refactor(cache): declare rollback_state as class attribute on ArraysCache --- mlx_lm/models/cache.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index 745f5ebfe..340079f2d 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -592,13 +592,14 @@ def nbytes(self): class ArraysCache(_BaseCache): + # Snapshot of (conv_state, ssm_state) saved after processing confirmed tokens + # in an MTP draft-verification step. Cleared after each step. + rollback_state: Optional[tuple] = None + def __new__(cls, *args, **kwargs): instance = super().__new__(cls) instance.left_padding = None instance.lengths = None - # Snapshot of (conv_state, ssm_state) saved after processing confirmed tokens - # in an MTP draft-verification step. Cleared after each step. - instance.rollback_state = None return instance def __init__(self, size, left_padding: Optional[List[int]] = None): From 8a5237958410e93fac3bbea0ba60837784279060 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Wed, 29 Apr 2026 17:06:57 +0200 Subject: [PATCH 17/29] fix(mtp): support input_embeddings in mtp_generate_step and fix logits_processors dimensionality Pass input_embeddings through _prefill for VLM prefill compatibility. Wrap/unwrap logits as 2D in _process_and_sample so logits_processors receive the expected [1, vocab] shape. --- mlx_lm/generate.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index adf175262..c9232328d 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -675,6 +675,7 @@ def mtp_generate_step( kv_bits: Optional[int] = None, kv_group_size: int = 64, quantized_kv_start: int = 0, + input_embeddings: Optional[mx.array] = None, ) -> Generator[Tuple[mx.array, mx.array, bool], None, None]: """A generator that uses the model's native MTP head for speculative decoding. @@ -717,8 +718,10 @@ def mtp_generate_step( def _process_and_sample(tokens, logits): if logits_processors: + logits = logits[None] for processor in logits_processors: logits = processor(tokens, logits) + logits = logits.squeeze(0) logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) return sampler(logprobs), logprobs @@ -783,20 +786,30 @@ def _step_mtp(hidden_last, main_tok, prev_tokens): draft_tok, draft_lp = _process_and_sample(tokens_for_proc, mtp_logits) return draft_tok, draft_lp - def _prefill(y): + def _prefill(y, input_embeddings): # Leave exactly 1 token for _step_backbone: return_hidden=True keeps # the hidden state [1, N, d_model] live, so N must be 1. - while y.size > 1: - n = min(prefill_step_size, y.size - 1) - model(y[:n][None], cache=model_cache) + total = len(input_embeddings) if input_embeddings is not None else y.size + while total > 1: + n = min(prefill_step_size, total - 1) + if input_embeddings is not None: + model( + y[:n][None], + cache=model_cache, + input_embeddings=input_embeddings[:n][None], + ) + input_embeddings = input_embeddings[n:] + else: + model(y[:n][None], cache=model_cache) quantize_cache_fn(model_cache) mx.eval([c.state for c in model_cache if hasattr(c, "state")]) y = y[n:] + total -= n mx.clear_cache() return y with mx.stream(generation_stream): - y = _prefill(y) + y = _prefill(y, input_embeddings) ntoks = 0 draft_tok = draft_lp = None From fae9fa131d1ae56600403b0d31075eadc1572221 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Wed, 29 Apr 2026 18:01:06 +0200 Subject: [PATCH 18/29] fix(mtp): support both fused and per-expert MTP weights in qwen3_5_moe sanitize Extract _unfuse_experts and _stack_per_expert helpers to eliminate duplicated logic between backbone and MTP conversion. Detect MTP format once before the loop (Qwen3.6 fused gate_up_proj vs Qwen3.5 per-expert) instead of re-checking per iteration. --- mlx_lm/models/qwen3_5_moe.py | 74 ++++++++++++++++++++---------------- 1 file changed, 42 insertions(+), 32 deletions(-) diff --git a/mlx_lm/models/qwen3_5_moe.py b/mlx_lm/models/qwen3_5_moe.py index f1f6ec42d..98fe47a19 100644 --- a/mlx_lm/models/qwen3_5_moe.py +++ b/mlx_lm/models/qwen3_5_moe.py @@ -20,6 +20,31 @@ def from_dict(cls, params): return super().from_dict(params) +def _unfuse_experts(weights, prefix): + """Split fused gate_up_proj into per-projection switch_mlp weights (Qwen3.6 format).""" + gate_up_key = f"{prefix}.experts.gate_up_proj" + if gate_up_key not in weights: + return + gate_up = weights.pop(gate_up_key) + mid = gate_up.shape[-2] // 2 + weights[f"{prefix}.switch_mlp.gate_proj.weight"] = gate_up[..., :mid, :] + weights[f"{prefix}.switch_mlp.up_proj.weight"] = gate_up[..., mid:, :] + weights[f"{prefix}.switch_mlp.down_proj.weight"] = weights.pop( + f"{prefix}.experts.down_proj" + ) + + +def _stack_per_expert(weights, prefix, num_experts): + """Stack per-expert weights into switch_mlp format (Qwen3.5 format).""" + for n in ("gate_proj", "up_proj", "down_proj"): + weights[f"{prefix}.switch_mlp.{n}.weight"] = mx.stack( + [ + weights.pop(f"{prefix}.experts.{e}.{n}.weight") + for e in range(num_experts) + ] + ) + + class Model(Qwen3_5Model): def sanitize(self, weights): @@ -29,42 +54,27 @@ def sanitize(self, weights): continue if key.startswith("model.language_model"): key = key.replace("model.language_model", "language_model.model") - elif key.startswith("language_model."): - pass - else: + elif not key.startswith("language_model."): key = "language_model." + key new_weights[key] = value + # Backbone MoE layers always use fused gate_up_proj (both Qwen3.5 and Qwen3.6). for l in range(self.language_model.args.num_hidden_layers): - prefix = f"language_model.model.layers.{l}.mlp" - gate_up_key = f"{prefix}.experts.gate_up_proj" - if gate_up_key in new_weights: - gate_up = new_weights.pop(gate_up_key) - mid = gate_up.shape[-2] // 2 - new_weights[f"{prefix}.switch_mlp.gate_proj.weight"] = gate_up[ - ..., :mid, : - ] - new_weights[f"{prefix}.switch_mlp.up_proj.weight"] = gate_up[ - ..., mid:, : - ] - new_weights[f"{prefix}.switch_mlp.down_proj.weight"] = new_weights.pop( - f"{prefix}.experts.down_proj" - ) - - # Stack per-expert MTP weights into switch_mlp format. - # MTP layers use unfused per-expert weights (experts.{i}.gate_proj etc) - # unlike backbone layers which use fused gate_up_proj. + _unfuse_experts(new_weights, f"language_model.model.layers.{l}.mlp") + + # MTP layers: fused format (Qwen3.6) or per-expert format (Qwen3.5). + # Detect format once from the first layer and apply uniformly. mtp_num = getattr(self.language_model.args, "mtp_num_hidden_layers", 0) - num_experts = self.language_model.args.num_experts - for layer_idx in range(mtp_num): - prefix = f"language_model.mtp.layers.{layer_idx}.mlp" - test_key = f"{prefix}.experts.0.gate_proj.weight" - if test_key in new_weights: - for n in ["gate_proj", "up_proj", "down_proj"]: - to_join = [ - new_weights.pop(f"{prefix}.experts.{e}.{n}.weight") - for e in range(num_experts) - ] - new_weights[f"{prefix}.switch_mlp.{n}.weight"] = mx.stack(to_join) + if mtp_num > 0: + num_experts = self.language_model.args.num_experts + mtp_is_fused = ( + "language_model.mtp.layers.0.mlp.experts.gate_up_proj" in new_weights + ) + for layer_idx in range(mtp_num): + prefix = f"language_model.mtp.layers.{layer_idx}.mlp" + if mtp_is_fused: + _unfuse_experts(new_weights, prefix) + else: + _stack_per_expert(new_weights, prefix, num_experts) return self.language_model.sanitize(new_weights) From 32fdaa38f01962ba960c6d755baa8ac1e7b4fe66 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Tue, 5 May 2026 00:57:20 +0200 Subject: [PATCH 19/29] fix(mtp): remove spurious mtp_cache trim on draft rejection On rejection, _rollback_draft trimmed mtp_cache by 1, but the 2-token backbone verification pass never writes to mtp_cache. The trim was removing the valid KV entry from the previous _step_mtp call, causing accumulated context drift after repeated rejections (lower acceptance rate over long generations). --- mlx_lm/generate.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index c9232328d..2bb3ca5e3 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -745,7 +745,6 @@ def _rollback_draft(): c.rollback_state = None elif c.is_trimmable(): c.trim(1) - cache.trim_prompt_cache(mtp_cache, 1) def _step_backbone(y, prev_tokens, n_predict=1, n_confirmed=0): """Run the backbone on ``y`` and return (tokens, logprobs, hidden, prev_tokens).""" From 13f157b517db784b4882b1a712b43e1754f055d8 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Tue, 5 May 2026 21:17:45 +0200 Subject: [PATCH 20/29] fix(mtp): use residual sampling on rejection at temp>0 On rejection, emit a token sampled from max(p_target - p_draft, 0) / Z instead of the backbone argmax. This guarantees the output marginal equals the target distribution exactly (Leviathan et al. 2022; Chen et al. 2023). --- mlx_lm/generate.py | 16 +++++++++++++++- tests/test_mtp.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 2bb3ca5e3..0874f359f 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -870,13 +870,27 @@ def _prefill(y, input_embeddings): if logits_processors and prev_tokens is not None: prev_tokens = prev_tokens[:-1] # discard rejected draft token verify_tok_id = verify_pred.item() + if not _is_greedy: + # Sample from residual distribution max(p_target - p_draft, 0) / Z + # (Leviathan et al. 2022 §2.3; Chen et al. 2023). Guarantees the + # output marginal equals the target distribution exactly. + p_target = mx.exp(verify_lp) + p_draft = mx.exp(draft_lp) + residual = mx.maximum(p_target - p_draft, 0.0) + z = residual.sum().item() + if z > 1e-8: + verify_tok_id = mx.random.categorical( + mx.log(residual / z + 1e-10).reshape(1, -1) + ).item() ntoks += 1 yield verify_tok_id, verify_lp, False if ntoks >= max_tokens: return # Next draft from MTP at y's hidden state. draft_tok, draft_lp = _step_mtp( - hidden_at_confirmed, verify_pred, prev_tokens + hidden_at_confirmed, + mx.array([verify_tok_id], mx.uint32), + prev_tokens, ) mx.eval(draft_tok) y = mx.array([verify_tok_id], mx.uint32) diff --git a/tests/test_mtp.py b/tests/test_mtp.py index 64068a203..d60f605fc 100644 --- a/tests/test_mtp.py +++ b/tests/test_mtp.py @@ -257,6 +257,41 @@ def forcing_processor(tokens, logits): # Second call (MTP head): context must be T0 = 4, not the prompt token. self.assertEqual(logged[1], 4) + def test_mtp_rejection_residual_sampling(self): + """On rejection at temp>0, the emitted token must be sampled from the + residual distribution max(p_target - p_draft, 0) / Z, not the backbone + argmax. The argmax is deterministic for a fixed model state, so rejection + tokens would always be identical. Residual sampling produces a + distribution, so tokens must vary across runs. + + Using max_tokens=1 yields exactly one token per run: the accepted draft + (from_draft=True) or the rejection token (from_draft=False). This avoids + conflating rejection tokens with bonus tokens (also from_draft=False). + """ + prompt = mx.array([0, 1, 2, 3], dtype=mx.uint32) + + def stochastic(logprobs): + return mx.random.categorical(logprobs) + + rejection_tokens: list[int] = [] + for _ in range(60): + for tok, _, from_draft in mtp_generate_step( + prompt, self.model, sampler=stochastic, max_tokens=1 + ): + if not from_draft: + rejection_tokens.append(int(tok)) + + self.assertGreaterEqual( + len(rejection_tokens), + 5, + "Too few rejection events observed; increase n_runs", + ) + self.assertGreater( + len(set(rejection_tokens)), + 1, + "Rejection tokens are always identical, argmax bug likely present", + ) + if __name__ == "__main__": unittest.main() From 65943483e46cf4f9471fefbc82acae4a1b371285 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Thu, 7 May 2026 01:40:44 +0200 Subject: [PATCH 21/29] fix(mtp): reduce residual sampling to 1 sync, correct z=0 fallback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove z.item() sync: z stays in the MLX graph and is evaluated once alongside categorical(), reducing Metal round-trips from 2 to 1. - Replace if z > 1e-8 guard with mx.where(z > 0, residual, p_target): when the residual mass is zero, sample from p_target instead of keeping verify_pred (argmax). Matches Leviathan et al. 2022 §2.3. --- mlx_lm/generate.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 0874f359f..113e01413 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -877,11 +877,12 @@ def _prefill(y, input_embeddings): p_target = mx.exp(verify_lp) p_draft = mx.exp(draft_lp) residual = mx.maximum(p_target - p_draft, 0.0) - z = residual.sum().item() - if z > 1e-8: - verify_tok_id = mx.random.categorical( - mx.log(residual / z + 1e-10).reshape(1, -1) - ).item() + z = residual.sum(keepdims=True) + dist = mx.where(z > 0, residual, p_target) + # categorical treats -inf log-prob as p=0. + verify_tok_id = mx.random.categorical( + mx.log(dist).reshape(1, -1) + ).item() ntoks += 1 yield verify_tok_id, verify_lp, False if ntoks >= max_tokens: From 87f1b09ccf1c6ec1b80f2863c9d769a5f01cb63e Mon Sep 17 00:00:00 2001 From: AirRunner Date: Thu, 7 May 2026 02:15:42 +0200 Subject: [PATCH 22/29] feat(mtp): native sampling params, XTC draw sharing, correct lp_accept Replace sampler= callable with explicit sampling params (temp, top_p, top_k, min_p, xtc_*) so mtp_generate_step can compute temperature-adjusted lp_accept for correct probabilistic acceptance at temp > 0. - Extract make_sampler_chain from make_sampler (DRY); mtp uses it directly to build the filter chain without a pre-assembled sampler. - Compute lp_accept from the filtered+scaled distribution so it matches the distribution the token was drawn from. - Share the XTC boolean draw across draft and verify steps via xtc_cell, so both steps apply the same XTC mask. - Draw acceptance coin as mx.random.uniform(), evaluated in parallel with the verify forward pass (amortized Metal dispatch, consistent with mx.random.seed()). - Fix _xtc_special_tokens: use tokenizer.eos_token_ids (plural) and concatenate properly instead of mixing int and list. - Update tests: remove sampler= from MTP tests, add top_k variant, extract _collect_rejection_tokens/_assert_residual_varies helpers. --- mlx_lm/generate.py | 147 ++++++++++++++++++++++++++++++++--------- mlx_lm/sample_utils.py | 60 +++++++++++++---- mlx_lm/server.py | 16 +++-- tests/test_mtp.py | 68 ++++++++++--------- 4 files changed, 211 insertions(+), 80 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 113e01413..fc5891df8 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -6,7 +6,6 @@ import functools import json import math -import random import sys import time import warnings @@ -41,7 +40,7 @@ TokenBuffer, load_prompt_cache, ) -from .sample_utils import make_sampler +from .sample_utils import categorical_sampling, make_sampler, make_sampler_chain from .tokenizer_utils import TokenizerWrapper from .utils import does_model_support_input_embeddings, load @@ -668,7 +667,6 @@ def mtp_generate_step( model: nn.Module, *, max_tokens: int = 256, - sampler: Optional[Callable[[mx.array], mx.array]] = None, logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, prompt_cache: Optional[Any] = None, prefill_step_size: int = 2048, @@ -676,6 +674,14 @@ def mtp_generate_step( kv_group_size: int = 64, quantized_kv_start: int = 0, input_embeddings: Optional[mx.array] = None, + temp: float = 0.0, + top_p: float = 0.0, + top_k: int = 0, + min_p: float = 0.0, + min_tokens_to_keep: int = 1, + xtc_probability: float = 0.0, + xtc_threshold: float = 0.0, + xtc_special_tokens: List[int] = [], ) -> Generator[Tuple[mx.array, mx.array, bool], None, None]: """A generator that uses the model's native MTP head for speculative decoding. @@ -704,10 +710,21 @@ def mtp_generate_step( model_cache = prompt_cache[:n_main] mtp_cache = prompt_cache[n_main:] or model.make_mtp_cache() - # Exact-match acceptance for greedy (sampler=None); probabilistic - # acceptance min(1, p_target/p_draft) for stochastic samplers. - _is_greedy = sampler is None - sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) + _is_greedy = temp == 0 + + _filter_chain, _xtc_cell = ( + make_sampler_chain( + top_p, + top_k, + min_p, + min_tokens_to_keep, + xtc_probability, + xtc_threshold, + xtc_special_tokens, + ) + if not _is_greedy + else ([], None) + ) quantize_cache_fn = functools.partial( maybe_quantize_kv_cache, @@ -716,14 +733,31 @@ def mtp_generate_step( kv_bits=kv_bits, ) - def _process_and_sample(tokens, logits): + def _process_and_sample(tokens, logits, xtc_draw=None): if logits_processors: logits = logits[None] for processor in logits_processors: logits = processor(tokens, logits) logits = logits.squeeze(0) logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) - return sampler(logprobs), logprobs + if _filter_chain: + if _xtc_cell is not None: + _xtc_cell[0] = xtc_draw # None = fresh draw; mx.array = shared draw + masked = logprobs + for f in _filter_chain: + masked = f(masked) + token = categorical_sampling(masked, temp) + # lp_accept must reflect the same filtered distribution as token. + scaled = masked / temp + lp_accept = scaled - mx.logsumexp(scaled, axis=-1, keepdims=True) + elif _is_greedy: + token = mx.argmax(logprobs, axis=-1) + lp_accept = logprobs + else: + token = categorical_sampling(logprobs, temp) + scaled = logprobs / temp + lp_accept = scaled - mx.logsumexp(scaled, axis=-1, keepdims=True) + return token, logprobs, lp_accept def _clear_rollback(): for c in model_cache: @@ -746,15 +780,15 @@ def _rollback_draft(): elif c.is_trimmable(): c.trim(1) - def _step_backbone(y, prev_tokens, n_predict=1, n_confirmed=0): - """Run the backbone on ``y`` and return (tokens, logprobs, hidden, prev_tokens).""" + def _step_backbone(y, prev_tokens, n_predict=1, n_confirmed=0, xtc_draw=None): + """Run the backbone on ``y`` and return (tokens, logprobs, accept_lps, hidden, prev_tokens).""" with mx.stream(generation_stream): logits, hidden = model( y[None], cache=model_cache, return_hidden=True, n_confirmed=n_confirmed ) logits = logits[:, -n_predict:, :] quantize_cache_fn(model_cache) - toks, lps = [], [] + toks, lps, accept_lps = [], [], [] for i in range(n_predict): if logits_processors: prev_tokens = ( @@ -762,13 +796,24 @@ def _step_backbone(y, prev_tokens, n_predict=1, n_confirmed=0): if prev_tokens is not None else y[i : i + 1] ) - tok, lp = _process_and_sample(prev_tokens, logits[:, i, :].squeeze(0)) + # Pass the shared XTC draw only for position 0 (the verify position). + draw = xtc_draw if i == 0 else None + tok, lp, alp = _process_and_sample( + prev_tokens, logits[:, i, :].squeeze(0), draw + ) toks.append(tok) lps.append(lp) - return mx.stack(toks), mx.stack(lps), hidden, prev_tokens + accept_lps.append(alp) + return ( + mx.stack(toks), + mx.stack(lps), + mx.stack(accept_lps), + hidden, + prev_tokens, + ) def _step_mtp(hidden_last, main_tok, prev_tokens): - """Run the MTP head and return (draft_token, draft_logprobs).""" + """Run the MTP head and return (draft_token, draft_logprobs, draft_accept_lp, xtc_draw).""" next_ids = main_tok.reshape(1, 1) with mx.stream(generation_stream): mtp_logits = model.mtp_forward(hidden_last, next_ids, mtp_cache) @@ -782,8 +827,12 @@ def _step_mtp(hidden_last, main_tok, prev_tokens): ) else: tokens_for_proc = prev_tokens - draft_tok, draft_lp = _process_and_sample(tokens_for_proc, mtp_logits) - return draft_tok, draft_lp + # Draw the XTC boolean once here so the verify step can reuse it. + xtc_draw = mx.random.uniform() if _xtc_cell is not None else None + draft_tok, draft_lp, draft_accept_lp = _process_and_sample( + tokens_for_proc, mtp_logits, xtc_draw + ) + return draft_tok, draft_lp, draft_accept_lp, xtc_draw def _prefill(y, input_embeddings): # Leave exactly 1 token for _step_backbone: return_hidden=True keeps @@ -811,12 +860,14 @@ def _prefill(y, input_embeddings): y = _prefill(y, input_embeddings) ntoks = 0 - draft_tok = draft_lp = None + draft_tok = draft_lp = draft_accept_lp = draft_xtc_draw = None while ntoks < max_tokens: if draft_tok is None: # No pending draft: run backbone only, then generate first draft. - toks, lps, hidden, prev_tokens = _step_backbone(y, prev_tokens, n_predict=1) + toks, lps, accept_lps, hidden, prev_tokens = _step_backbone( + y, prev_tokens, n_predict=1 + ) mx.eval(toks) main_tok, main_lp = toks[0], lps[0] ntoks += 1 @@ -824,7 +875,9 @@ def _prefill(y, input_embeddings): if ntoks >= max_tokens: return hidden_at_main = hidden[:, -1:, :] - draft_tok, draft_lp = _step_mtp(hidden_at_main, main_tok, prev_tokens) + draft_tok, draft_lp, draft_accept_lp, draft_xtc_draw = _step_mtp( + hidden_at_main, main_tok, prev_tokens + ) mx.eval(draft_tok) y = mx.array([main_tok.item()], mx.uint32) else: @@ -832,21 +885,29 @@ def _prefill(y, input_embeddings): # n_confirmed=1 causes GatedDeltaNet to snapshot its SSM/conv state # after the confirmed token y, enabling exact rollback on rejection. y_with_draft = mx.concatenate([y, mx.array([draft_tok.item()], mx.uint32)]) - toks, lps, hidden, prev_tokens = _step_backbone( - y_with_draft, prev_tokens, n_predict=2, n_confirmed=1 + u = mx.random.uniform() + toks, lps, accept_lps, hidden, prev_tokens = _step_backbone( + y_with_draft, + prev_tokens, + n_predict=2, + n_confirmed=1, + xtc_draw=draft_xtc_draw, ) - mx.eval(toks, draft_tok) + mx.eval(toks, draft_tok, u) verify_pred, bonus_tok = toks[0], toks[1] verify_lp, bonus_lp = lps[0], lps[1] + verify_accept_lp = accept_lps[0] draft_tok_id = draft_tok.item() if _is_greedy: accept = verify_pred.item() == draft_tok_id else: - # Probabilistic acceptance: min(1, p_target / p_draft). - log_accept = (verify_lp[draft_tok_id] - draft_lp[draft_tok_id]).item() - accept = log_accept >= 0 or random.random() < math.exp(log_accept) + # Probabilistic acceptance: min(1, p_target/p_draft) with temp-adjusted logprobs. + log_accept = ( + verify_accept_lp[draft_tok_id] - draft_accept_lp[draft_tok_id] + ).item() + accept = log_accept >= 0 or u.item() < math.exp(log_accept) hidden_at_confirmed = hidden[:, 0:1, :] hidden_at_draft = hidden[:, 1:2, :] @@ -862,7 +923,9 @@ def _prefill(y, input_embeddings): if ntoks >= max_tokens: return # Next draft from MTP at draft_tok's hidden state. - draft_tok, draft_lp = _step_mtp(hidden_at_draft, bonus_tok, prev_tokens) + draft_tok, draft_lp, draft_accept_lp, draft_xtc_draw = _step_mtp( + hidden_at_draft, bonus_tok, prev_tokens + ) mx.eval(draft_tok) y = mx.array([bonus_tok.item()], mx.uint32) else: @@ -874,8 +937,9 @@ def _prefill(y, input_embeddings): # Sample from residual distribution max(p_target - p_draft, 0) / Z # (Leviathan et al. 2022 §2.3; Chen et al. 2023). Guarantees the # output marginal equals the target distribution exactly. - p_target = mx.exp(verify_lp) - p_draft = mx.exp(draft_lp) + # Both distributions are temperature-adjusted to match sampling. + p_target = mx.exp(verify_accept_lp) + p_draft = mx.exp(draft_accept_lp) residual = mx.maximum(p_target - p_draft, 0.0) z = residual.sum(keepdims=True) dist = mx.where(z > 0, residual, p_target) @@ -888,7 +952,7 @@ def _prefill(y, input_embeddings): if ntoks >= max_tokens: return # Next draft from MTP at y's hidden state. - draft_tok, draft_lp = _step_mtp( + draft_tok, draft_lp, draft_accept_lp, draft_xtc_draw = _step_mtp( hidden_at_confirmed, mx.array([verify_tok_id], mx.uint32), prev_tokens, @@ -904,6 +968,14 @@ def stream_generate( max_tokens: int = 256, draft_model: Optional[nn.Module] = None, mtp: bool = False, + temp: float = 0.0, + top_p: float = 0.0, + top_k: int = 0, + min_p: float = 0.0, + min_tokens_to_keep: int = 1, + xtc_probability: float = 0.0, + xtc_threshold: float = 0.0, + xtc_special_tokens: List[int] = [], **kwargs, ) -> Generator[GenerationResponse, None, None]: """ @@ -954,7 +1026,20 @@ def stream_generate( kwargs.pop("max_kv_size", None) kwargs.pop("prompt_progress_callback", None) kwargs.pop("num_draft_tokens", None) - token_generator = mtp_generate_step(prompt, model, **kwargs) + kwargs.pop("sampler", None) # mtp_generate_step does not accept sampler= + token_generator = mtp_generate_step( + prompt, + model, + temp=temp, + top_p=top_p, + top_k=top_k, + min_p=min_p, + min_tokens_to_keep=min_tokens_to_keep, + xtc_probability=xtc_probability, + xtc_threshold=xtc_threshold, + xtc_special_tokens=xtc_special_tokens, + **kwargs, + ) else: if mtp: warnings.warn( diff --git a/mlx_lm/sample_utils.py b/mlx_lm/sample_utils.py index 05a45fc60..dd6f18d5a 100644 --- a/mlx_lm/sample_utils.py +++ b/mlx_lm/sample_utils.py @@ -2,11 +2,42 @@ import math from functools import partial -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional, Tuple import mlx.core as mx +def make_sampler_chain( + top_p: float = 0.0, + top_k: int = 0, + min_p: float = 0.0, + min_tokens_to_keep: int = 1, + xtc_probability: float = 0.0, + xtc_threshold: float = 0.0, + xtc_special_tokens: List[int] = [], +) -> Tuple[List[Callable[[mx.array], mx.array]], Optional[List]]: + """Return (filter_chain, xtc_cell) for use in mtp_generate_step. + + xtc_cell is a mutable [uniform_draw] slot; set it before running the chain + to share the XTC boolean across draft and verify steps. None when XTC is off. + """ + _xtc_cell: Optional[List] = [None] if xtc_probability > 0.0 else None + chain: List[Callable[[mx.array], mx.array]] = [] + if top_p > 0 and top_p < 1.0: + chain.append(lambda x: apply_top_p(x, top_p)) + if min_p != 0.0: + chain.append(lambda x: apply_min_p(x, min_p, min_tokens_to_keep)) + if xtc_probability > 0.0: + chain.append( + lambda x: apply_xtc( + x, xtc_probability, xtc_threshold, xtc_special_tokens, _xtc_cell[0] + ) + ) + if top_k > 0: + chain.append(lambda x: apply_top_k(x, top_k)) + return chain, _xtc_cell + + def make_sampler( temp: float = 0.0, top_p: float = 0.0, @@ -47,17 +78,15 @@ def make_sampler( return lambda x: mx.argmax(x, axis=-1) # Create sampler chain - sampling_methods = [] - if top_p > 0 and top_p < 1.0: - sampling_methods.append(lambda x: apply_top_p(x, top_p)) - if min_p != 0.0: - sampling_methods.append(lambda x: apply_min_p(x, min_p, min_tokens_to_keep)) - if xtc_probability > 0.0: - sampling_methods.append( - lambda x: apply_xtc(x, xtc_probability, xtc_threshold, xtc_special_tokens) - ) - if top_k > 0: - sampling_methods.append(lambda x: apply_top_k(x, top_k)) + sampling_methods, _ = make_sampler_chain( + top_p, + top_k, + min_p, + min_tokens_to_keep, + xtc_probability, + xtc_threshold, + xtc_special_tokens, + ) # Apply the sampling methods def sampler(logprobs): @@ -243,6 +272,7 @@ def apply_xtc( xtc_probability: float, xtc_threshold: float, xtc_special_tokens: List[int], + p_draw: Optional[mx.array] = None, ) -> mx.array: """ Apply XTC sampling to the logits. @@ -252,6 +282,9 @@ def apply_xtc( xtc_probability (float): Probability of XTC sampling to happen for each token xtc_threshold (float): The threshold the probs need to reach for being sampled. special_tokens_ids (list(int)): List of special tokens IDs to be excluded from XTC sampling. + p_draw (mx.array, optional): Pre-drawn uniform; if None, draws fresh. + Pass the same draw to draft and verify steps to share the XTC + apply/skip decision across both forward passes. """ if not (0 <= xtc_threshold <= 0.5): raise ValueError( @@ -267,8 +300,9 @@ def apply_xtc( if xtc_special_tokens: mask[..., xtc_special_tokens] = False + draw = mx.random.uniform(0, 1) if p_draw is None else p_draw return mx.where( - mx.random.uniform(0, 1) > xtc_probability, + draw > xtc_probability, logits, mx.where(mask, -mx.inf, logits), ) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index d843b0261..b1d6214ee 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -396,6 +396,10 @@ def load(self, model_path, adapter_path=None, draft_model_path=None): return self.model, self.tokenizer +def _xtc_special_tokens(tokenizer): + return tokenizer.encode("\n") + list(tokenizer.eos_token_ids) + + def _make_sampler(args, tokenizer): return make_sampler( args.sampling.temperature, @@ -404,10 +408,7 @@ def _make_sampler(args, tokenizer): min_p=args.sampling.min_p, xtc_probability=args.sampling.xtc_probability, xtc_threshold=args.sampling.xtc_threshold, - xtc_special_tokens=[ - tokenizer.eos_token_id, - tokenizer.encode("\n"), - ], + xtc_special_tokens=_xtc_special_tokens(tokenizer), ) @@ -993,6 +994,13 @@ def progress(tokens_processed, tokens_total): prompt_progress_callback=progress, prefill_step_size=self.cli_args.prefill_step_size, mtp=getattr(self.cli_args, "mtp", False), + temp=args.sampling.temperature, + top_p=args.sampling.top_p, + top_k=args.sampling.top_k, + min_p=args.sampling.min_p, + xtc_probability=args.sampling.xtc_probability, + xtc_threshold=args.sampling.xtc_threshold, + xtc_special_tokens=_xtc_special_tokens(tokenizer), ): finish_reason = gen.finish_reason sm_state, match_sequence, current_state = sm.match(sm_state, gen.token) diff --git a/tests/test_mtp.py b/tests/test_mtp.py index d60f605fc..d7e6a7cc4 100644 --- a/tests/test_mtp.py +++ b/tests/test_mtp.py @@ -153,23 +153,24 @@ def test_mtp_generate_identity(self): def test_mtp_probabilistic_acceptance_completes(self): """mtp_generate_step should complete without errors with a stochastic sampler. - Exercises the probabilistic acceptance path: min(1, p_target / p_draft). + Exercises the probabilistic acceptance path: min(1, p_target / p_draft), + both with bare temp (no filters) and with top_k applied. """ prompt = mx.array([0, 1, 2, 3], dtype=mx.uint32) n_tokens = 10 - def stochastic(logprobs): - return mx.random.categorical(logprobs) - - tokens = [] - for tok, _, _ in mtp_generate_step( - prompt, self.model, sampler=stochastic, max_tokens=n_tokens - ): - tokens.append(int(tok)) - if len(tokens) >= n_tokens: - break - - self.assertEqual(len(tokens), n_tokens) + for kwargs in [ + {"temp": 0.7}, + {"temp": 0.7, "top_k": 10}, + ]: + tokens = [] + for tok, _, _ in mtp_generate_step( + prompt, self.model, max_tokens=n_tokens, **kwargs + ): + tokens.append(int(tok)) + if len(tokens) >= n_tokens: + break + self.assertEqual(len(tokens), n_tokens, f"kwargs={kwargs}") def test_mtp_generate_identity_with_logits_processor(self): """mtp_generate_step must produce the same greedy tokens as generate_step @@ -257,41 +258,44 @@ def forcing_processor(tokens, logits): # Second call (MTP head): context must be T0 = 4, not the prompt token. self.assertEqual(logged[1], 4) - def test_mtp_rejection_residual_sampling(self): - """On rejection at temp>0, the emitted token must be sampled from the - residual distribution max(p_target - p_draft, 0) / Z, not the backbone - argmax. The argmax is deterministic for a fixed model state, so rejection - tokens would always be identical. Residual sampling produces a - distribution, so tokens must vary across runs. - - Using max_tokens=1 yields exactly one token per run: the accepted draft - (from_draft=True) or the rejection token (from_draft=False). This avoids - conflating rejection tokens with bonus tokens (also from_draft=False). - """ + def _collect_rejection_tokens(self, n_runs=60, **kwargs): + """Run mtp_generate_step n_runs times with max_tokens=1 and return all + rejection tokens (from_draft=False).""" prompt = mx.array([0, 1, 2, 3], dtype=mx.uint32) - - def stochastic(logprobs): - return mx.random.categorical(logprobs) - rejection_tokens: list[int] = [] - for _ in range(60): + for _ in range(n_runs): for tok, _, from_draft in mtp_generate_step( - prompt, self.model, sampler=stochastic, max_tokens=1 + prompt, self.model, max_tokens=1, **kwargs ): if not from_draft: rejection_tokens.append(int(tok)) + return rejection_tokens + def _assert_residual_varies(self, rejection_tokens, label=""): self.assertGreaterEqual( len(rejection_tokens), 5, - "Too few rejection events observed; increase n_runs", + f"{label}Too few rejection events observed; increase n_runs", ) self.assertGreater( len(set(rejection_tokens)), 1, - "Rejection tokens are always identical, argmax bug likely present", + f"{label}Rejection tokens are always identical, argmax bug likely present", ) + def test_mtp_rejection_residual_sampling(self): + """On rejection at temp>0, the emitted token must be sampled from the + residual distribution max(p_target - p_draft, 0) / Z, not the backbone + argmax. The argmax is deterministic for a fixed model state, so rejection + tokens would always be identical. Residual sampling produces a + distribution, so tokens must vary across runs. + + Using max_tokens=1 yields exactly one token per run: the accepted draft + (from_draft=True) or the rejection token (from_draft=False). This avoids + conflating rejection tokens with bonus tokens (also from_draft=False). + """ + self._assert_residual_varies(self._collect_rejection_tokens(temp=1.0)) + if __name__ == "__main__": unittest.main() From a2f137485ca0e09705cd6ae71c7ecda479a93de1 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Thu, 7 May 2026 16:29:59 +0200 Subject: [PATCH 23/29] quality: replace import functools with from functools import partial --- mlx_lm/generate.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index fc5891df8..0a89dfcde 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -3,7 +3,6 @@ import argparse import contextlib import copy -import functools import json import math import sys @@ -384,7 +383,7 @@ def generate_step( prompt_progress_callback = prompt_progress_callback or (lambda *_: None) - quantize_cache_fn = functools.partial( + quantize_cache_fn = partial( maybe_quantize_kv_cache, quantized_kv_start=quantized_kv_start, kv_group_size=kv_group_size, @@ -542,7 +541,7 @@ def speculative_generate_step( sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) - quantize_cache_fn = functools.partial( + quantize_cache_fn = partial( maybe_quantize_kv_cache, quantized_kv_start=quantized_kv_start, kv_group_size=kv_group_size, @@ -726,7 +725,7 @@ def mtp_generate_step( else ([], None) ) - quantize_cache_fn = functools.partial( + quantize_cache_fn = partial( maybe_quantize_kv_cache, quantized_kv_start=quantized_kv_start, kv_group_size=kv_group_size, From b1dad14a16bb9ea8661a56f62f55877c73207561 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Thu, 7 May 2026 16:38:30 +0200 Subject: [PATCH 24/29] fix(mtp): prefill MTP cache during prompt prefill Previously _prefill only populated the backbone cache, leaving the MTP KVCache cold at the start of decode. The MTP head was trained with full prefix context, so starting from an empty cache is misaligned with training. Now each prefill chunk passes return_hidden=True and immediately calls mtp_forward(hidden, y[1:n+1], mtp_cache). The hidden tensor is transient: consumed within the same iteration before mx.clear_cache(). --- mlx_lm/generate.py | 12 +++++++----- mlx_lm/models/qwen3_5.py | 6 +++--- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 0a89dfcde..53014662c 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -834,22 +834,24 @@ def _step_mtp(hidden_last, main_tok, prev_tokens): return draft_tok, draft_lp, draft_accept_lp, xtc_draw def _prefill(y, input_embeddings): - # Leave exactly 1 token for _step_backbone: return_hidden=True keeps - # the hidden state [1, N, d_model] live, so N must be 1. + # Leave exactly 1 token for _step_backbone so the decode loop starts clean. total = len(input_embeddings) if input_embeddings is not None else y.size while total > 1: n = min(prefill_step_size, total - 1) if input_embeddings is not None: - model( + _, hidden = model( y[:n][None], cache=model_cache, + return_hidden=True, input_embeddings=input_embeddings[:n][None], ) input_embeddings = input_embeddings[n:] else: - model(y[:n][None], cache=model_cache) + _, hidden = model(y[:n][None], cache=model_cache, return_hidden=True) + model.mtp_forward(hidden, y[1 : n + 1][None], mtp_cache) + quantize_cache_fn(mtp_cache) quantize_cache_fn(model_cache) - mx.eval([c.state for c in model_cache if hasattr(c, "state")]) + mx.eval([c.state for c in model_cache + mtp_cache if hasattr(c, "state")]) y = y[n:] total -= n mx.clear_cache() diff --git a/mlx_lm/models/qwen3_5.py b/mlx_lm/models/qwen3_5.py index 9a09c0738..b95fa0ea2 100644 --- a/mlx_lm/models/qwen3_5.py +++ b/mlx_lm/models/qwen3_5.py @@ -430,12 +430,12 @@ def mtp_forward( """Run the MTP head and apply the shared lm_head. Args: - hidden_states: Backbone pre-norm hidden state, shape (B, 1, H). - next_token_ids: Sampled main token ids, shape (B, 1). + hidden_states: Backbone pre-norm hidden state (B, N, H). N=1 during decode, N>1 during prompt prefill. + next_token_ids: Next token ids, shape (B, N). mtp_cache: KVCache entries for the MTP transformer layers. Returns: - logits of shape (B, 1, vocab_size). + logits of shape (B, N, vocab_size). """ mtp_out = self.mtp( hidden_states, From ffac4333279d571b18727244f658b414bddb6ca5 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Fri, 8 May 2026 02:25:23 +0200 Subject: [PATCH 25/29] fix(mtp): clear Metal allocator cache every 256 tokens during decode generate_step calls mx.clear_cache() every 256 tokens to bound the Metal allocator's free list. Introduce _CACHE_CLEAR_INTERVAL = 256 shared by both generate_step and mtp_generate_step to add the equivalent cache-clearing logic to the MTP decode loop. The block-based counter (ntoks // _CACHE_CLEAR_INTERVAL) handles MTP iterations that could emit multiple tokens at once, where a '% interval == 0' check could skip a boundary. --- mlx_lm/generate.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 53014662c..d6cc6eae4 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -55,6 +55,7 @@ DEFAULT_SEED = None DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" DEFAULT_QUANTIZED_KV_START = 5000 +_CACHE_CLEAR_INTERVAL = 256 def str2bool(string): @@ -471,7 +472,7 @@ def _step(input_tokens: mx.array, input_embeddings: Optional[mx.array] = None): if n == max_tokens: break yield y.item(), logprobs - if n % 256 == 0: + if n % _CACHE_CLEAR_INTERVAL == 0: mx.clear_cache() y, logprobs = next_y, next_logprobs n += 1 @@ -861,6 +862,7 @@ def _prefill(y, input_embeddings): y = _prefill(y, input_embeddings) ntoks = 0 + last_cache_block = 0 draft_tok = draft_lp = draft_accept_lp = draft_xtc_draw = None while ntoks < max_tokens: @@ -960,6 +962,10 @@ def _prefill(y, input_embeddings): ) mx.eval(draft_tok) y = mx.array([verify_tok_id], mx.uint32) + block = ntoks // _CACHE_CLEAR_INTERVAL + if block > last_cache_block: + mx.clear_cache() + last_cache_block = block def stream_generate( From a5a82a936c9509a91dfd9bff338179f6bc58d796 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Sat, 9 May 2026 20:39:15 +0200 Subject: [PATCH 26/29] style(mtp): move u after _step_backbone Declare u = mx.random.uniform() immediately before its first use (mx.eval) rather than before the unrelated _step_backbone call. --- mlx_lm/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index d6cc6eae4..d41888965 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -888,7 +888,6 @@ def _prefill(y, input_embeddings): # n_confirmed=1 causes GatedDeltaNet to snapshot its SSM/conv state # after the confirmed token y, enabling exact rollback on rejection. y_with_draft = mx.concatenate([y, mx.array([draft_tok.item()], mx.uint32)]) - u = mx.random.uniform() toks, lps, accept_lps, hidden, prev_tokens = _step_backbone( y_with_draft, prev_tokens, @@ -896,6 +895,7 @@ def _prefill(y, input_embeddings): n_confirmed=1, xtc_draw=draft_xtc_draw, ) + u = mx.random.uniform() mx.eval(toks, draft_tok, u) verify_pred, bonus_tok = toks[0], toks[1] From c47c1cb63d8d4fc2273957cd3286d70ddfae71d1 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Mon, 11 May 2026 19:51:00 +0200 Subject: [PATCH 27/29] qwen3_5: remove mtp.fc exclusion from quant_predicate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Empirical benchmarks (Qwen3.6-27B 4-bit, M4 Pro, temp=0/0.6/1.0) show no measurable impact on MTP acceptance rate when mtp.fc is quantized to 4-bit: acceptance delta is within noise (−0.2 to +0.3 pp), speedup delta within noise (−0.003 to +0.026x). Additionally, keeping mtp.fc in BF16 penalizes M1 users where BF16 has no native GPU support. --- mlx_lm/models/qwen3_5.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mlx_lm/models/qwen3_5.py b/mlx_lm/models/qwen3_5.py index b95fa0ea2..289989db8 100644 --- a/mlx_lm/models/qwen3_5.py +++ b/mlx_lm/models/qwen3_5.py @@ -504,9 +504,6 @@ def quant_predicate(self): def predicate(path, _): if path.endswith("mlp.gate") or path.endswith("shared_expert_gate"): return {"group_size": 64, "bits": 8} - # Keep the MTP fusion projection in full precision. - if path.endswith("mtp.fc"): - return False return True if self.args.num_experts <= 0 and self.args.mtp_num_hidden_layers <= 0: From 6222938fb5113eef06c8bf6a1740934ef252452f Mon Sep 17 00:00:00 2001 From: AirRunner Date: Fri, 15 May 2026 16:42:27 +0200 Subject: [PATCH 28/29] test(mtp): remove stale quant_predicate test The test was checking mtp.fc exclusion, which was removed in c47c1cb after empirical benchmarks. --- tests/test_mtp.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tests/test_mtp.py b/tests/test_mtp.py index d7e6a7cc4..9996dddf9 100644 --- a/tests/test_mtp.py +++ b/tests/test_mtp.py @@ -106,16 +106,6 @@ def test_hidden_is_pre_norm(self): normed = lm.model.norm(hidden) self.assertFalse(mx.allclose(hidden, normed, atol=1e-5).item()) - def test_quant_predicate_excludes_mtp_fc(self): - """quant_predicate should exclude mtp.fc from quantization.""" - lm = self.model.language_model - predicate = lm.quant_predicate - self.assertIsNotNone(predicate) - # mtp.fc should not be quantized - self.assertFalse(predicate("mtp.fc", None)) - # Regular layers should be quantized - self.assertTrue(predicate("layers.0.mlp.gate_proj", None)) - def test_mtp_generate_identity(self): """mtp_generate_step should produce the same greedy tokens as generate_step. From f840f6cff725773b29272112bedab6980e7a3c5e Mon Sep 17 00:00:00 2001 From: AirRunner Date: Sat, 16 May 2026 00:06:08 +0200 Subject: [PATCH 29/29] fix(mtp): commit accepted draft token to mtp_cache via batched forward The accepted draft token was never processed by the MTP head, causing the cache to drift behind the backbone cache by one entry per accept. After k accepts the MTP head operates on k tokens of missing context. Empirically the impact was negligible though (backbone hidden dominates MTP head conditioning). Fix: extend _step_mtp with an optional cache_commit=(hidden, tok) parameter. When set, the alignment position and the draft position are processed in a single 2-token batched mtp_forward, committing the accepted token to mtp_cache at no extra forward-pass cost. --- mlx_lm/generate.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index d41888965..fb54053d5 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -812,9 +812,20 @@ def _step_backbone(y, prev_tokens, n_predict=1, n_confirmed=0, xtc_draw=None): prev_tokens, ) - def _step_mtp(hidden_last, main_tok, prev_tokens): - """Run the MTP head and return (draft_token, draft_logprobs, draft_accept_lp, xtc_draw).""" - next_ids = main_tok.reshape(1, 1) + def _step_mtp(hidden_last, main_tok, prev_tokens, *, cache_commit=None): + """Run the MTP head and return (draft_token, draft_logprobs, draft_accept_lp, xtc_draw). + + cache_commit: (hidden, tok) prepended as a cache-alignment position so that the + accepted draft token is committed to mtp_cache in the same batched forward. + """ + if cache_commit is not None: + align_h, align_tok = cache_commit + hidden_last = mx.concatenate([align_h, hidden_last], axis=1) + next_ids = mx.concatenate( + [align_tok.reshape(1, 1), main_tok.reshape(1, 1)], axis=1 + ) + else: + next_ids = main_tok.reshape(1, 1) with mx.stream(generation_stream): mtp_logits = model.mtp_forward(hidden_last, next_ids, mtp_cache) quantize_cache_fn(mtp_cache) @@ -925,9 +936,13 @@ def _prefill(y, input_embeddings): yield bonus_tok.item(), bonus_lp, False if ntoks >= max_tokens: return - # Next draft from MTP at draft_tok's hidden state. + # Next draft: one batched forward aligns the cache for the + # accepted draft token and generates the next draft together. draft_tok, draft_lp, draft_accept_lp, draft_xtc_draw = _step_mtp( - hidden_at_draft, bonus_tok, prev_tokens + hidden_at_draft, + bonus_tok, + prev_tokens, + cache_commit=(hidden_at_confirmed, draft_tok), ) mx.eval(draft_tok) y = mx.array([bonus_tok.item()], mx.uint32)