diff --git a/docs/ngram-mtp-speculative-decoding-study.md b/docs/ngram-mtp-speculative-decoding-study.md new file mode 100644 index 000000000..5a5786590 --- /dev/null +++ b/docs/ngram-mtp-speculative-decoding-study.md @@ -0,0 +1,216 @@ +# N-gram + MTP Speculative Decoding Notes + +## Goal + +Test whether n-gram speculation can make long roleplay generation faster in oMLX, and whether it should be combined with MTP. + +## Short Answer + +Yes, it helps on long repeated conversations. + +Best current routing: + +```text +1. Try short used-priority n-gram draft. +2. If n-gram misses, try MTP fallback. +3. If MTP fallback is not accepting enough, disable it for the rest of the request. +4. Fall back to plain target greedy when needed. +``` + +## Current Recommended Settings + +```text +ngram_spec_enabled = true +ngram_spec_n_match = 4 +ngram_spec_draft_min = 1 +ngram_spec_draft_max = 2 +ngram_spec_min_count = 3 +ngram_spec_min_confidence = 0.8 +ngram_spec_max_entries = 2048 +ngram_spec_mtp_fallback = true +ngram_spec_mtp_adaptive = true +ngram_spec_mtp_min_cycles = 8 +ngram_spec_mtp_min_accept_rate = 0.5 +``` + +## Key Ideas + +### 1. N-gram should be short + +Long n-gram drafts caused problems. + +On the 40-turn roleplay test: + +| `draft_max` | Correct | Speed | +|---:|---:|---:| +| 1 | yes | 61.70 tok/s | +| 2 | yes | 72.50 tok/s | +| 4 | no | diverged | +| 8 | no | diverged | + +So the default should stay small: + +```text +ngram_spec_draft_max = 2 +``` + +### 2. Used n-grams should win over frequent n-grams + +Prompt frequency alone is not enough. + +Example: + +```text +Key: Archive keeper: The +Frequent prompt continuation: eastern aisle... +Current live continuation: western stair... +``` + +If the model already used `western` in this generation, that should be prioritized over the more frequent prompt branch. + +So the implementation uses: + +```text +used n-gram table first +frequency table second +``` + +### 3. MTP helps, but not everywhere + +MTP-only works on many prompts. + +Example 40-turn run: + +```text +Plain greedy: 46.67 tok/s +MTP-only: 53.09 tok/s +``` + +But n-gram helps more on repeated conversations: + +```text +N-gram target fallback: 66.15 tok/s +``` + +MTP is best used as a fallback after n-gram misses, not as the main strategy for repeated roleplay text. + +### 4. Adaptive MTP fallback is best + +MTP fallback can help, but if it starts rejecting too much, it becomes overhead. + +So we track MTP fallback accept rate per request. + +If accept rate is too low after enough cycles, MTP fallback is disabled for the rest of the request. + +## Benchmark Results + +### 40-turn roleplay benchmark + +Generation length: 320 tokens. + +| Path | Correct | wall tok/s | decode tok/s | +|---|---:|---:|---:| +| Plain greedy | yes | 48.72 | 62.67 | +| N-gram + target fallback | yes | 67.44 | 101.33 | +| N-gram + MTP fallback | yes | 68.85 | 103.87 | +| N-gram + adaptive MTP fallback | yes | 69.61 | 104.35 | + +Best result: + +```text +N-gram + adaptive MTP fallback +69.61 tok/s wall throughput +104.35 tok/s decode throughput +``` + +### Prompt-shape matrix + +| Case | Best path | Result | +|---|---|---| +| Low-repeat prose | MTP-only | small gain | +| Short repeated oath | N-gram | small gain | +| 40-turn conversation | N-gram + adaptive MTP | large gain | +| Branch-heavy repeated prompt | unsafe | speculative paths diverged | + +## Example N-gram Suggestions + +From the 40-turn roleplay prompt: + +| Key | Suggested draft | +|---|---| +| `remember the river,` | ` the tower` | +| `the river, the` | ` tower,` | +| `river, the tower` | `, and` | +| `the tower, and` | ` the name` | +| `and the name beneath` | ` the glass` | +| `Mira: The` | ` river mark` | +| `The river mark is` | ` still cold` | + +Replay stats: + +```text +315 n-gram suggestion events +313 full matches +629 drafted tokens +627 accepted-prefix tokens +``` + +## N-gram vs MTP Overlap + +Diagnostic test on the 40-turn prompt: + +```text +overlap events: 110 +first-token agree: 52 +first-token disagree: 58 +agreement rate: 47.3% +``` + +Meaning: + +- N-gram and MTP are not redundant. +- N-gram is better at exact repeated text. +- MTP is better as a local model-based fallback. + +Example agreement: + +| Key | N-gram | MTP | +|---|---|---| +| `tower, and the` | ` name beneath` | ` name` | +| `The river mark is` | ` still cold` | ` still` | + +Example disagreement: + +| Key | N-gram | MTP | +|---|---|---| +| `Mira: The` | ` river mark` | ` name` | +| `Archive keeper:` | ` Name the` | ` The` | + +## Remaining Risks + +The branch-heavy repeated prompt still diverged under speculative modes. + +So this is not yet a universal production-safe optimization for every prompt shape. + +Safe target use case: + +```text +long repeated conversation / roleplay structure +greedy decoding +short n-gram drafts +adaptive MTP fallback +``` + +## Conclusion + +N-gram speculation is useful for long roleplay conversations because the text has repeated structure. + +MTP also works, but it is better as an adaptive fallback. + +The best current policy is: + +```text +short used-priority n-gram first +adaptive MTP fallback second +plain target greedy fallback when MTP stops helping +``` diff --git a/omlx/engine/batched.py b/omlx/engine/batched.py index 12d41de16..58db4747b 100644 --- a/omlx/engine/batched.py +++ b/omlx/engine/batched.py @@ -252,6 +252,42 @@ def _load_model_sync(): self._model = apply_post_load_transforms(self._model, self._model_settings) + if self._model_settings is not None: + self._model._omlx_ngram_spec_enabled = bool( + getattr(self._model_settings, "ngram_spec_enabled", False) + ) + self._model._omlx_ngram_spec_n_match = int( + getattr(self._model_settings, "ngram_spec_n_match", 4) or 4 + ) + self._model._omlx_ngram_spec_draft_min = int( + getattr(self._model_settings, "ngram_spec_draft_min", 1) or 1 + ) + self._model._omlx_ngram_spec_draft_max = int( + getattr(self._model_settings, "ngram_spec_draft_max", 2) or 2 + ) + self._model._omlx_ngram_spec_min_count = int( + getattr(self._model_settings, "ngram_spec_min_count", 3) or 3 + ) + self._model._omlx_ngram_spec_min_confidence = float( + getattr(self._model_settings, "ngram_spec_min_confidence", 0.8) or 0.8 + ) + self._model._omlx_ngram_spec_max_entries = int( + getattr(self._model_settings, "ngram_spec_max_entries", 2048) or 2048 + ) + self._model._omlx_ngram_spec_mtp_fallback = bool( + getattr(self._model_settings, "ngram_spec_mtp_fallback", True) + ) + self._model._omlx_ngram_spec_mtp_adaptive = bool( + getattr(self._model_settings, "ngram_spec_mtp_adaptive", True) + ) + self._model._omlx_ngram_spec_mtp_min_cycles = int( + getattr(self._model_settings, "ngram_spec_mtp_min_cycles", 8) or 8 + ) + self._model._omlx_ngram_spec_mtp_min_accept_rate = float( + getattr(self._model_settings, "ngram_spec_mtp_min_accept_rate", 0.5) + or 0.5 + ) + # TurboQuant KV cache: patch attention and set kv_bits on scheduler if self._model_settings is not None: tq_enabled = getattr(self._model_settings, "turboquant_kv_enabled", False) @@ -440,7 +476,8 @@ def count_chat_tokens( messages = self._preprocess_messages(messages) template_tools = convert_tools_for_template(tools) if tools else None prompt = self._apply_chat_template( - messages, template_tools, + messages, + template_tools, chat_template_kwargs=chat_template_kwargs, is_partial=is_partial, ) @@ -675,8 +712,10 @@ async def chat( ct_kwargs = kwargs.pop("chat_template_kwargs", None) partial = kwargs.pop("is_partial", None) prompt = self._apply_chat_template( - messages, template_tools, - chat_template_kwargs=ct_kwargs, is_partial=partial, + messages, + template_tools, + chat_template_kwargs=ct_kwargs, + is_partial=partial, ) return await self.generate( @@ -735,8 +774,10 @@ async def stream_chat( ct_kwargs = kwargs.pop("chat_template_kwargs", None) partial = kwargs.pop("is_partial", None) prompt = self._apply_chat_template( - messages, template_tools, - chat_template_kwargs=ct_kwargs, is_partial=partial, + messages, + template_tools, + chat_template_kwargs=ct_kwargs, + is_partial=partial, ) # SpecPrefill: compute system prompt token count for protection. diff --git a/omlx/model_profiles.py b/omlx/model_profiles.py index 1e3a64fe6..4375cee7d 100644 --- a/omlx/model_profiles.py +++ b/omlx/model_profiles.py @@ -61,6 +61,17 @@ "vlm_mtp_enabled", "vlm_mtp_draft_model", "vlm_mtp_draft_block_size", + "ngram_spec_enabled", + "ngram_spec_n_match", + "ngram_spec_draft_min", + "ngram_spec_draft_max", + "ngram_spec_min_count", + "ngram_spec_min_confidence", + "ngram_spec_max_entries", + "ngram_spec_mtp_fallback", + "ngram_spec_mtp_adaptive", + "ngram_spec_mtp_min_cycles", + "ngram_spec_mtp_min_accept_rate", "specprefill_enabled", "specprefill_draft_model", "specprefill_keep_pct", @@ -69,18 +80,20 @@ ) # Excluded — never stored in a profile or template. -EXCLUDED_FROM_PROFILES = frozenset({ - "is_pinned", - "is_default", - "display_name", - "description", - "model_alias", - "model_type_override", - "active_profile_name", - "ttl_seconds", - # Security flag must be explicit per model — never propagated via profiles. - "trust_remote_code", -}) +EXCLUDED_FROM_PROFILES = frozenset( + { + "is_pinned", + "is_default", + "display_name", + "description", + "model_alias", + "model_type_override", + "active_profile_name", + "ttl_seconds", + # Security flag must be explicit per model — never propagated via profiles. + "trust_remote_code", + } +) def filter_universal_fields(data: dict[str, Any]) -> dict[str, Any]: diff --git a/omlx/model_settings.py b/omlx/model_settings.py index 4140fcbc8..bffe304ce 100644 --- a/omlx/model_settings.py +++ b/omlx/model_settings.py @@ -88,6 +88,17 @@ class ModelSettings: "gemma4_assistant" model. vlm_mtp_draft_model: Path/repo of the assistant drafter (e.g. "gemma-4-26B-A4B-it-assistant"). vlm_mtp_draft_block_size: Tokens drafted per round (None = mlx-vlm default). + ngram_spec_enabled: Enable draftless n-gram speculation before native MTP. + ngram_spec_n_match: Number of recent tokens used as the n-gram lookup key. + ngram_spec_draft_min: Minimum n-gram draft length required before verification. + ngram_spec_draft_max: Maximum n-gram draft length to verify in one forward pass. + ngram_spec_min_count: Minimum repeated occurrences before an n-gram key is usable. + ngram_spec_min_confidence: Minimum most-common follower ratio before a key is usable. + ngram_spec_max_entries: Maximum frequently repeated n-gram keys kept for drafting. + ngram_spec_mtp_fallback: Use native MTP behind n-gram misses when available. + ngram_spec_mtp_adaptive: Disable MTP fallback for a request if local accept rate is low. + ngram_spec_mtp_min_cycles: Minimum fallback cycles before adaptive disabling can trigger. + ngram_spec_mtp_min_accept_rate: Minimum fallback accept rate to keep using MTP. is_pinned: Keep model loaded in memory. is_default: Use this model when no model is specified. display_name: Human-readable name for UI display. @@ -107,25 +118,43 @@ class ModelSettings: force_sampling: bool = False max_tool_result_tokens: Optional[int] = None chat_template_kwargs: Optional[Dict[str, Any]] = None - forced_ct_kwargs: Optional[list[str]] = None # Keys that cannot be overridden by API requests + forced_ct_kwargs: Optional[list[str]] = ( + None # Keys that cannot be overridden by API requests + ) ttl_seconds: Optional[int] = None # Auto-unload after idle seconds (None = no TTL) - model_type_override: Optional[str] = None # "llm", "vlm", "embedding", "reranker", or None (auto-detect) - model_alias: Optional[str] = None # API-visible name (alternative to directory name) - index_cache_freq: Optional[int] = None # IndexCache: every Nth layer keeps indexer (DSA models only) - enable_thinking: Optional[bool] = None # Explicit toggle for thinking/reasoning mode (None = auto) - preserve_thinking: Optional[bool] = None # Keep blocks in historical turns (None = auto, True when template supports it) + model_type_override: Optional[str] = ( + None # "llm", "vlm", "embedding", "reranker", or None (auto-detect) + ) + model_alias: Optional[str] = ( + None # API-visible name (alternative to directory name) + ) + index_cache_freq: Optional[int] = ( + None # IndexCache: every Nth layer keeps indexer (DSA models only) + ) + enable_thinking: Optional[bool] = ( + None # Explicit toggle for thinking/reasoning mode (None = auto) + ) + preserve_thinking: Optional[bool] = ( + None # Keep blocks in historical turns (None = auto, True when template supports it) + ) thinking_budget_enabled: bool = False thinking_budget_tokens: Optional[int] = None - reasoning_parser: Optional[str] = None # xgrammar builtin name: "qwen", "harmony", "llama", etc. + reasoning_parser: Optional[str] = ( + None # xgrammar builtin name: "qwen", "harmony", "llama", etc. + ) # TurboQuant KV cache (mlx-vlm backend) turboquant_kv_enabled: bool = False turboquant_kv_bits: float = 4 # 2, 2.5, 3, 3.5, 4, 6, 8 - turboquant_skip_last: bool = True # Skip last KVCache layer (prevents corruption on sensitive models) + turboquant_skip_last: bool = ( + True # Skip last KVCache layer (prevents corruption on sensitive models) + ) # SpecPrefill (experimental: attention-based sparse prefill for MoE models) specprefill_enabled: bool = False - specprefill_draft_model: Optional[str] = None # Path to draft model (must share tokenizer) + specprefill_draft_model: Optional[str] = ( + None # Path to draft model (must share tokenizer) + ) specprefill_keep_pct: Optional[float] = None # Keep rate (0.1-0.5, default 0.2) specprefill_threshold: Optional[int] = None # Min tokens to trigger (default 8192) @@ -136,13 +165,21 @@ class ModelSettings: dflash_draft_quant_weight_bits: Optional[int] = None # 2, 4, 8 dflash_draft_quant_activation_bits: Optional[int] = None # 16, 32 dflash_draft_quant_group_size: Optional[int] = None # 32, 64, 128 - dflash_max_ctx: Optional[int] = None # None = unlimited; trigger BatchedEngine fallback when prompt_len >= this + dflash_max_ctx: Optional[int] = ( + None # None = unlimited; trigger BatchedEngine fallback when prompt_len >= this + ) # DFlash prefix cache (private to dflash; separate from omlx tiered cache because # snapshots include draft model GDN state and target hidden chunks omlx never tracks) dflash_in_memory_cache: bool = True - dflash_in_memory_cache_max_entries: int = 4 # Matches dflash balanced profile default - dflash_in_memory_cache_max_bytes: int = 8 * 1024 * 1024 * 1024 # 8 GiB (balanced profile default) - dflash_ssd_cache: bool = False # Requires in-memory cache and an omlx paged SSD cache dir + dflash_in_memory_cache_max_entries: int = ( + 4 # Matches dflash balanced profile default + ) + dflash_in_memory_cache_max_bytes: int = ( + 8 * 1024 * 1024 * 1024 + ) # 8 GiB (balanced profile default) + dflash_ssd_cache: bool = ( + False # Requires in-memory cache and an omlx paged SSD cache dir + ) # DFlash runtime tuning knobs. None = let dflash-mlx pick its own DEFAULT_RUNTIME_CONFIG # value (currently window=1024, sink=64, verify_mode="adaptive"). Surfaced for long-context # agentic workloads where acceptance drops on the default sliding window. @@ -155,13 +192,32 @@ class ModelSettings: # qwen3_5*, qwen3_6*, deepseek_v4*. Mutually exclusive with dflash and turboquant. mtp_enabled: bool = False + # Draftless n-gram speculation composed with native MTP. The n-gram drafter + # runs first on single-request greedy decoding; when it cannot produce a + # useful draft the normal MTP draft path is used. + ngram_spec_enabled: bool = False + ngram_spec_n_match: int = 4 + ngram_spec_draft_min: int = 1 + ngram_spec_draft_max: int = 2 + ngram_spec_min_count: int = 3 + ngram_spec_min_confidence: float = 0.8 + ngram_spec_max_entries: int = 2048 + ngram_spec_mtp_fallback: bool = True + ngram_spec_mtp_adaptive: bool = True + ngram_spec_mtp_min_cycles: int = 8 + ngram_spec_mtp_min_accept_rate: float = 0.5 + # VLM MTP speculative decoding via external assistant drafter (mlx-vlm f96138e+). # Target = Gemma4 VLM body, drafter = "gemma-4-26B-A4B-it-assistant" # (model_type "gemma4_assistant"). Mutually exclusive with all other speculative # paths because the wrapper bypasses mlx-lm BatchGenerator at decode time. vlm_mtp_enabled: bool = False - vlm_mtp_draft_model: Optional[str] = None # Path / model id of the assistant drafter - vlm_mtp_draft_block_size: Optional[int] = None # Tokens per draft round (None = mlx-vlm default) + vlm_mtp_draft_model: Optional[str] = ( + None # Path / model id of the assistant drafter + ) + vlm_mtp_draft_block_size: Optional[int] = ( + None # Tokens per draft round (None = mlx-vlm default) + ) # Model management flags is_pinned: bool = False @@ -192,6 +248,28 @@ def __post_init__(self) -> None: "mtp_enabled and turboquant_kv_enabled cannot both be True; " "TurboQuant patches the attention path that MTP relies on" ) + if self.ngram_spec_enabled and not self.mtp_enabled: + raise ValueError( + "ngram_spec_enabled requires mtp_enabled=True; the current " + "implementation composes n-gram drafts with the native MTP " + "BatchGenerator path" + ) + if self.ngram_spec_n_match < 1: + raise ValueError("ngram_spec_n_match must be >= 1") + if self.ngram_spec_draft_min < 1: + raise ValueError("ngram_spec_draft_min must be >= 1") + if self.ngram_spec_draft_max < self.ngram_spec_draft_min: + raise ValueError("ngram_spec_draft_max must be >= ngram_spec_draft_min") + if self.ngram_spec_min_count < 1: + raise ValueError("ngram_spec_min_count must be >= 1") + if not 0.0 < self.ngram_spec_min_confidence <= 1.0: + raise ValueError("ngram_spec_min_confidence must be > 0 and <= 1") + if self.ngram_spec_max_entries < 1: + raise ValueError("ngram_spec_max_entries must be >= 1") + if self.ngram_spec_mtp_min_cycles < 1: + raise ValueError("ngram_spec_mtp_min_cycles must be >= 1") + if not 0.0 <= self.ngram_spec_mtp_min_accept_rate <= 1.0: + raise ValueError("ngram_spec_mtp_min_accept_rate must be >= 0 and <= 1") # vlm_mtp wraps mlx-vlm's MTP loop and bypasses mlx-lm BatchGenerator # at decode time, so it cannot coexist with any other speculative path # or with TurboQuant (which mutates the same cache objects). @@ -327,7 +405,7 @@ def _save(self) -> None: "models": { model_id: settings.to_dict() for model_id, settings in self._settings.items() - } + }, } try: @@ -488,7 +566,9 @@ def save_profile( with self._lock: per_model = self._profiles.setdefault(model_id, {}) if name in per_model: - raise ValueError(f"Profile '{name}' already exists for model '{model_id}'") + raise ValueError( + f"Profile '{name}' already exists for model '{model_id}'" + ) now = utcnow().isoformat() per_model[name] = { "name": name, diff --git a/omlx/patches/mlx_lm_mtp/batch_generator.py b/omlx/patches/mlx_lm_mtp/batch_generator.py index c751c98a5..35fbcc414 100644 --- a/omlx/patches/mlx_lm_mtp/batch_generator.py +++ b/omlx/patches/mlx_lm_mtp/batch_generator.py @@ -73,6 +73,7 @@ # Public entry point # --------------------------------------------------------------------------- + def apply() -> bool: """Wrap ``GenerationBatch.__init__`` + ``GenerationBatch.next``.""" global _PATCHED @@ -124,9 +125,7 @@ def patched_next(self, *args, **kwargs): try: return _mtp_next(self, state) except _MtpStepFallback as exc: - logger.debug( - "MTP next() fallback to standard step: %s", exc - ) + logger.debug("MTP next() fallback to standard step: %s", exc) # Best-effort: drop state so subsequent calls don't try # to resume a half-built MTP cycle from a stale snapshot. if hasattr(self, "_omlx_mtp_state"): @@ -220,6 +219,7 @@ def _is_mtp_eligible(gen_batch: Any) -> bool: return False try: from . import is_mtp_active + if not is_mtp_active(): return False except Exception: @@ -247,6 +247,7 @@ def _ineligibility_reason(gen_batch: Any) -> str: return "model has no attached mtp head" try: from . import is_mtp_active + if not is_mtp_active(): return "mtp_active flag is off (model_settings.mtp_enabled was False at load time)" except Exception: @@ -267,6 +268,7 @@ class _MtpStepFallback(RuntimeError): # State # --------------------------------------------------------------------------- + @dataclass class _MtpStats: """Acceptance / throughput counters for one MTP-active sequence. @@ -283,6 +285,15 @@ class _MtpStats: draft_emits: int = 0 # tokens emitted as accepted drafts bonus_emits: int = 0 # tokens emitted as bonus (accepted + emit_bonus) verify_emits: int = 0 # tokens emitted as verify-position correction (reject path) + ngram_emits: int = 0 # tokens emitted from fully accepted n-gram drafts + ngram_cycles: int = 0 # number of n-gram verify cycles run + ngram_accepts: int = 0 # n-gram cycles whose whole draft was accepted + ngram_rejects: int = 0 # n-gram cycles that fell back to target token + ngram_draft_tokens: int = 0 # total n-gram draft tokens verified + ngram_accepted_tokens: int = 0 # total n-gram draft tokens accepted + mtp_fallback_cycles: int = 0 # MTP verify cycles run behind n-gram misses + mtp_fallback_accepts: int = 0 # accepted MTP fallback cycles + mtp_fallback_disabled: bool = False # adaptive routing disabled MTP fallback # Component-level timings. Help diagnose where MTP overhead comes from # when accept rate is healthy but wall-clock throughput isn't. backbone_ms: float = 0.0 # cumulative time inside the 2-token verify forward @@ -321,14 +332,33 @@ class _MtpState: # GPU→CPU sync (`int(draft_tok.tolist()[0])` would force a stall). draft_id: int = -1 + # Lazy MTP refresh state. Consecutive n-gram cycles do not need an MTP + # draft; keep the confirmed hidden/token pair and build the MTP draft + # only when n-gram misses and the normal MTP verifier is actually needed. + pending_mtp_hidden: Optional[Any] = None + pending_mtp_token: Optional[Any] = None + pending_mtp_prev_buf: Optional[Any] = None + # Accept-rate / throughput counters. Surfaced via logger.info on finish. stats: _MtpStats = field(default_factory=_MtpStats) + mtp_fallback_cycles: int = 0 + mtp_fallback_accepts: int = 0 + mtp_fallback_disabled: bool = False + + # Draftless n-gram speculation state. ``ngram_used`` tracks followers that + # were confirmed during this inference and is consulted before the prompt + # frequency table in ``ngram_index``. + ngram_index: dict[tuple[int, ...], int] = field(default_factory=dict) + ngram_used: dict[tuple[int, ...], int] = field(default_factory=dict) + ngram_counts: dict[tuple[int, ...], dict[int, int]] = field(default_factory=dict) + ngram_indexed_until: int = 0 # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- + def _get_generation_stream(): """Return the ``mlx_lm.generate`` module-level generation stream. @@ -355,10 +385,9 @@ def _resolve_sampler(gen_batch: Any): def _is_greedy(gen_batch: Any) -> bool: - """Heuristic mirroring PR 990's ``sampler is None``.""" - if gen_batch.samplers and gen_batch.samplers[0] is not None: - return False - return True + """Return True when the resolved oMLX sampler is deterministic argmax.""" + sampler = _resolve_sampler(gen_batch) + return float(getattr(sampler, "temp", 0.0) or 0.0) == 0.0 def _proc_list(gen_batch: Any) -> Optional[List[Any]]: @@ -429,7 +458,7 @@ def _trim_token_buffer(gen_batch: Any, n: int) -> None: buf._size = max(0, buf._size - n) -def _restore_or_trim_caches(prompt_cache: List[Any]) -> bool: +def _restore_or_trim_caches(prompt_cache: List[Any], trim_tokens: int = 1) -> bool: """Roll back one token from each layer cache after a draft rejection. SSM / linear-attention layers expose ``rollback_state`` populated by the @@ -447,7 +476,8 @@ def _restore_or_trim_caches(prompt_cache: List[Any]) -> bool: c.rollback_state = None continue if hasattr(c, "is_trimmable") and c.is_trimmable(): - c.trim(1) + if trim_tokens: + c.trim(trim_tokens) continue return False return True @@ -482,11 +512,10 @@ def _rollback_after_reject( non-MTP step. """ if gdn_states is not None and hasattr(model, "rollback_speculative_cache"): - model.rollback_speculative_cache( - prompt_cache, gdn_states, accepted, block_size - ) + model.rollback_speculative_cache(prompt_cache, gdn_states, accepted, block_size) return True - return _restore_or_trim_caches(prompt_cache) + trim_tokens = max(0, block_size - (accepted + 1)) + return _restore_or_trim_caches(prompt_cache, trim_tokens=trim_tokens) def _call_backbone( @@ -518,9 +547,7 @@ def _call_backbone( return result if len(result) == 2: return result[0], result[1], None - raise TypeError( - f"backbone returned unexpected shape: {type(result).__name__}" - ) + raise TypeError(f"backbone returned unexpected shape: {type(result).__name__}") def _clear_rollback(prompt_cache: List[Any]) -> None: @@ -544,6 +571,7 @@ def _ensure_uint32(arr): # emitted tokens; stash a draft for the first verify cycle. # --------------------------------------------------------------------------- + def _post_init_mtp(gen_batch: Any) -> None: """Bridge from standard ``__init__``'s ``_step()`` into PR 990's cycle 1. @@ -605,9 +633,7 @@ def _post_init_mtp(gen_batch: Any) -> None: prev_with_main_and_next = mx.concatenate( [prev_buf, _ensure_uint32(next_main_tok)] ) - mtp_logits_2d = _apply_processors( - procs, prev_with_main_and_next, mtp_logits_2d - ) + mtp_logits_2d = _apply_processors(procs, prev_with_main_and_next, mtp_logits_2d) draft_lp_2d = _logprobs(mtp_logits_2d) draft_tok = sampler(draft_lp_2d) # Filtered draft lp — what the sampler actually drew from. The next @@ -639,6 +665,7 @@ def _post_init_mtp(gen_batch: Any) -> None: # next() dispatch # --------------------------------------------------------------------------- + def _mtp_next(gen_batch: Any, state: _MtpState) -> Any: """Emit one token; run a verify cycle if the queue is empty.""" if state.queue: @@ -646,6 +673,40 @@ def _mtp_next(gen_batch: Any, state: _MtpState) -> Any: _bump_emit_stat(state, source) return _emit_response(gen_batch, token_id, logprobs_1d, state.stats) + if _try_ngram_cycle(gen_batch, state): + if state.queue: + token_id, logprobs_1d, source = state.queue.popleft() + _bump_emit_stat(state, source) + return _emit_response(gen_batch, token_id, logprobs_1d, state.stats) + + if _ngram_enabled(gen_batch) and _is_greedy(gen_batch): + use_mtp_fallback = ( + _ngram_mtp_fallback_enabled(gen_batch) and not state.mtp_fallback_disabled + ) + if use_mtp_fallback: + cycles_before = state.stats.cycles + accepts_before = state.stats.accepts + _ensure_mtp_draft(gen_batch, state) + _run_verify_cycle(gen_batch, state) + _record_used_from_queue(gen_batch, state) + _update_mtp_fallback_route( + gen_batch, + state, + cycles_before=cycles_before, + accepts_before=accepts_before, + ) + if not state.queue: + raise _MtpStepFallback("verify cycle produced no emit tokens") + token_id, logprobs_1d, source = state.queue.popleft() + _bump_emit_stat(state, source) + return _emit_response(gen_batch, token_id, logprobs_1d, state.stats) + + _run_target_greedy_cycle(gen_batch, state) + token_id, logprobs_1d, source = state.queue.popleft() + _bump_emit_stat(state, source) + return _emit_response(gen_batch, token_id, logprobs_1d, state.stats) + + _ensure_mtp_draft(gen_batch, state) _run_verify_cycle(gen_batch, state) if not state.queue: # Verify cycle should always populate the queue with at least the @@ -668,7 +729,11 @@ def _log_mtp_stats(uid: Any, stats: "_MtpStats", finish_reason: str) -> None: timing[backbone=ms mtp=ms sample=ms cache=ms] """ total_emits = ( - stats.init_emits + stats.draft_emits + stats.bonus_emits + stats.verify_emits + stats.init_emits + + stats.draft_emits + + stats.bonus_emits + + stats.verify_emits + + stats.ngram_emits ) if stats.cycles > 0: rate_str = f"{stats.accepts / stats.cycles * 100:.1f}%" @@ -676,7 +741,9 @@ def _log_mtp_stats(uid: Any, stats: "_MtpStats", finish_reason: str) -> None: rate_str = "n/a" logger.info( "MTP[%s] finish=%s tokens=%d cycles=%d accept=%d/%d (%s) " - "emits[init=%d,draft=%d,bonus=%d,verify=%d] " + "ngram[cycles=%d accept=%d reject=%d tokens=%d/%d] " + "mtp_fallback[cycles=%d accept=%d disabled=%d] " + "emits[init=%d,draft=%d,bonus=%d,verify=%d,ngram=%d] " "timing[backbone=%.1fms mtp=%.1fms sample=%.1fms cache=%.1fms]", uid, finish_reason, @@ -685,10 +752,19 @@ def _log_mtp_stats(uid: Any, stats: "_MtpStats", finish_reason: str) -> None: stats.accepts, stats.cycles, rate_str, + stats.ngram_cycles, + stats.ngram_accepts, + stats.ngram_rejects, + stats.ngram_accepted_tokens, + stats.ngram_draft_tokens, + stats.mtp_fallback_cycles, + stats.mtp_fallback_accepts, + 1 if stats.mtp_fallback_disabled else 0, stats.init_emits, stats.draft_emits, stats.bonus_emits, stats.verify_emits, + stats.ngram_emits, stats.backbone_ms, stats.mtp_head_ms, stats.sample_ms, @@ -705,12 +781,499 @@ def _bump_emit_stat(state: _MtpState, source: str) -> None: state.stats.bonus_emits += 1 elif source == "verify": state.stats.verify_emits += 1 + elif source == "ngram": + state.stats.ngram_emits += 1 + + +# --------------------------------------------------------------------------- +# Draftless n-gram speculation. This composes with native MTP by trying a +# history-derived draft first; when it misses or rejects, the normal MTP +# state is refreshed and remains the fallback for the next cycle. +# --------------------------------------------------------------------------- + + +def _ngram_enabled(gen_batch: Any) -> bool: + return bool(getattr(gen_batch.model, "_omlx_ngram_spec_enabled", False)) + + +def _ngram_mtp_fallback_enabled(gen_batch: Any) -> bool: + return bool(getattr(gen_batch.model, "_omlx_ngram_spec_mtp_fallback", True)) + + +def _ngram_mtp_adaptive_enabled(gen_batch: Any) -> bool: + return bool(getattr(gen_batch.model, "_omlx_ngram_spec_mtp_adaptive", True)) + + +def _ngram_mtp_adaptive_config(gen_batch: Any) -> tuple[int, float]: + min_cycles = max( + 1, int(getattr(gen_batch.model, "_omlx_ngram_spec_mtp_min_cycles", 8)) + ) + min_rate = min( + 1.0, + max( + 0.0, + float( + getattr(gen_batch.model, "_omlx_ngram_spec_mtp_min_accept_rate", 0.5) + ), + ), + ) + return min_cycles, min_rate + + +def _ngram_config(gen_batch: Any) -> tuple[int, int, int, int, float, int]: + n_match = max(1, int(getattr(gen_batch.model, "_omlx_ngram_spec_n_match", 4))) + draft_min = max(1, int(getattr(gen_batch.model, "_omlx_ngram_spec_draft_min", 1))) + draft_max = max( + draft_min, int(getattr(gen_batch.model, "_omlx_ngram_spec_draft_max", 2)) + ) + min_count = max(1, int(getattr(gen_batch.model, "_omlx_ngram_spec_min_count", 3))) + min_confidence = min( + 1.0, + max( + 0.0, float(getattr(gen_batch.model, "_omlx_ngram_spec_min_confidence", 0.8)) + ), + ) + max_entries = max( + 1, int(getattr(gen_batch.model, "_omlx_ngram_spec_max_entries", 2048)) + ) + return n_match, draft_min, draft_max, min_count, min_confidence, max_entries + + +def _update_ngram_index( + state: _MtpState, + history: list[int], + n_match: int, + min_count: int, + min_confidence: float, + max_entries: int, +) -> None: + """Index only repeated n-grams with a stable most-common follower.""" + last_start = len(history) - n_match - 1 + if last_start < state.ngram_indexed_until: + return + for i in range(state.ngram_indexed_until, last_start + 1): + key = tuple(history[i : i + n_match]) + follower = int(history[i + n_match]) + counts = state.ngram_counts.setdefault(key, {}) + counts[follower] = counts.get(follower, 0) + 1 + best_token, best_count = max( + counts.items(), key=lambda item: (item[1], item[0]) + ) + total = sum(counts.values()) + confidence = best_count / total if total else 0.0 + if best_count < min_count or confidence < min_confidence: + state.ngram_index.pop(key, None) + continue + if key in state.ngram_index or len(state.ngram_index) < max_entries: + state.ngram_index[key] = int(best_token) + state.ngram_indexed_until = last_start + 1 + + +def _draft_from_ngram( + state: _MtpState, + history: list[int], + n_match: int, + min_count: int, + min_confidence: float, + draft_max: int, +) -> list[int]: + """Iteratively draft by feeding predicted tokens back into the lookup key.""" + if len(history) < n_match: + return [] + draft: list[int] = [] + rolling = [int(t) for t in history[-n_match:]] + for _ in range(draft_max): + key = tuple(rolling[-n_match:]) + used_token = state.ngram_used.get(key) + if used_token is not None: + draft.append(int(used_token)) + rolling.append(int(used_token)) + continue + next_token = state.ngram_index.get(key) + if next_token is None: + break + counts = state.ngram_counts.get(key) + if not counts: + break + best_count = counts.get(next_token, 0) + total = sum(counts.values()) + if best_count < min_count or not total or best_count / total < min_confidence: + break + draft.append(int(next_token)) + rolling.append(int(next_token)) + return draft + + +def _ngram_history(gen_batch: Any) -> list[int]: + """Return prompt + emitted-token history for n-gram lookup. + + ``GenerationBatch.tokens[0]`` is not a reliable full-context source in + oMLX's external-prefill path; in practice it can contain only emitted + completion tokens. ``_token_context[0]`` is initialized from the prompt + for logits processors, so combine it with the emitted list when needed. + """ + emitted = [int(t) for t in gen_batch.tokens[0]] + try: + uid = gen_batch.uids[0] + prompt_map = getattr(gen_batch.model, "_omlx_ngram_prompt_tokens", None) + mapped_prompt = prompt_map.get(uid) if prompt_map is not None else None + if mapped_prompt: + return [int(t) for t in mapped_prompt] + emitted + except Exception: + pass + try: + prompt_tokens = [int(t) for t in gen_batch._token_context[0].tokens.tolist()] + except Exception: + return emitted + + if not prompt_tokens: + return emitted + if ( + len(emitted) >= len(prompt_tokens) + and emitted[: len(prompt_tokens)] == prompt_tokens + ): + return emitted + return prompt_tokens + emitted + + +def _defer_mtp_refresh( + gen_batch: Any, + state: _MtpState, + hidden_at_position: Any, + token: Any, + prev_buf: Optional[Any], +) -> None: + state.mtp_cache = gen_batch.model.make_mtp_cache() + state.pending_mtp_hidden = hidden_at_position + state.pending_mtp_token = _ensure_uint32(token) + state.pending_mtp_prev_buf = prev_buf + state.draft_tok = None + state.draft_lp = None + state.draft_accept_lp = None + state.draft_id = -1 + + +def _ensure_mtp_draft(gen_batch: Any, state: _MtpState) -> None: + if state.draft_tok is not None: + return + if state.pending_mtp_hidden is None or state.pending_mtp_token is None: + raise _MtpStepFallback("MTP draft unavailable") + new_draft, new_draft_lp = _step_mtp( + gen_batch, + state.pending_mtp_hidden, + state.pending_mtp_token, + prev_buf=state.pending_mtp_prev_buf, + stats=state.stats, + ) + state.draft_tok = new_draft + state.draft_lp = new_draft_lp + state.pending_mtp_hidden = None + state.pending_mtp_token = None + state.pending_mtp_prev_buf = None + + +def _mark_used_ngram( + state: _MtpState, + history: list[int], + follower: int, + n_match: int, + max_entries: int, +) -> None: + if len(history) < n_match: + return + if ( + len(state.ngram_used) >= max_entries + and tuple(history[-n_match:]) not in state.ngram_used + ): + return + state.ngram_used[tuple(history[-n_match:])] = int(follower) + + +def _mark_used_ngram_chain( + state: _MtpState, + history: list[int], + followers: list[int], + n_match: int, + max_entries: int, +) -> None: + rolling = [int(t) for t in history] + for follower in followers: + _mark_used_ngram(state, rolling, int(follower), n_match, max_entries) + rolling.append(int(follower)) + + +def _record_used_from_queue(gen_batch: Any, state: _MtpState) -> None: + if not state.queue: + return + n_match, _, _, _, _, max_entries = _ngram_config(gen_batch) + followers = [int(token_id) for token_id, _, _ in state.queue] + _mark_used_ngram_chain( + state, + _ngram_history(gen_batch), + followers, + n_match, + max_entries, + ) + + +def _update_mtp_fallback_route( + gen_batch: Any, + state: _MtpState, + cycles_before: int, + accepts_before: int, +) -> None: + cycles_delta = state.stats.cycles - cycles_before + if cycles_delta <= 0: + return + accepts_delta = state.stats.accepts - accepts_before + state.mtp_fallback_cycles += cycles_delta + state.mtp_fallback_accepts += accepts_delta + state.stats.mtp_fallback_cycles = state.mtp_fallback_cycles + state.stats.mtp_fallback_accepts = state.mtp_fallback_accepts + if not _ngram_mtp_adaptive_enabled(gen_batch): + return + min_cycles, min_rate = _ngram_mtp_adaptive_config(gen_batch) + if state.mtp_fallback_cycles < min_cycles: + return + rate = state.mtp_fallback_accepts / state.mtp_fallback_cycles + if rate < min_rate: + state.mtp_fallback_disabled = True + state.stats.mtp_fallback_disabled = True + + +def _record_high_prob_ngram( + state: _MtpState, + history: list[int], + follower: int, + n_match: int, + min_count: int, + min_confidence: float, + max_entries: int, +) -> None: + """Save a target-greedy continuation observed during inference. + + Prompt frequency gives the first n-gram table. Target-greedy fallback + gives us model-confirmed high-probability followers for the generated + branch, so future repeats in the same request can draft from the actual + branch the model is taking rather than only the most common prompt branch. + """ + if len(history) < n_match: + return + key = tuple(history[-n_match:]) + _mark_used_ngram(state, history, follower, n_match, max_entries) + counts = state.ngram_counts.setdefault(key, {}) + counts[int(follower)] = counts.get(int(follower), 0) + 1 + best_token, best_count = max(counts.items(), key=lambda item: (item[1], item[0])) + total = sum(counts.values()) + confidence = best_count / total if total else 0.0 + if best_count >= min_count and confidence >= min_confidence: + if key in state.ngram_index or len(state.ngram_index) < max_entries: + state.ngram_index[key] = int(best_token) + + +def _run_target_greedy_cycle(gen_batch: Any, state: _MtpState) -> None: + """Correct greedy fallback for n-gram mode. + + The native MTP fallback can drift on the Qwen GDN path. When n-gram mode + is enabled, use a plain target-model greedy step on misses/reject gaps and + record that model-confirmed continuation for later n-gram drafting. + """ + import time + + import mlx.core as mx + + if state.next_main is None: + raise _MtpStepFallback("target fallback entered without next_main") + + procs = _proc_list(gen_batch) + prev_buf = None + if procs is not None: + prev_buf = gen_batch._token_context[0].update_and_fetch(state.next_main) + + t0 = time.perf_counter() + with mx.stream(_get_generation_stream()): + logits, _, _ = _call_backbone( + gen_batch.model, + state.next_main[:, None], + gen_batch.prompt_cache, + ) + next_logits = logits[:, -1, :] + state.stats.backbone_ms += (time.perf_counter() - t0) * 1000 + + t0 = time.perf_counter() + if procs is not None: + next_logits = _apply_processors(procs, prev_buf, next_logits) + next_tok = mx.argmax(next_logits, axis=-1).reshape(-1) + mx.eval(next_tok) + emit_id = int(next_tok.tolist()[0]) + state.stats.sample_ms += (time.perf_counter() - t0) * 1000 + + n_match, _, _, min_count, min_confidence, max_entries = _ngram_config(gen_batch) + _record_high_prob_ngram( + state, + _ngram_history(gen_batch), + emit_id, + n_match, + min_count, + min_confidence, + max_entries, + ) + + emit_tok = _ensure_uint32(next_tok) + state.next_main = emit_tok + _defer_mtp_refresh(gen_batch, state, None, emit_tok, None) + state.queue.append((emit_id, None, "verify")) + + +def _try_ngram_cycle(gen_batch: Any, state: _MtpState) -> bool: + """Run one all-or-nothing n-gram verify cycle. + + The cache is kept only when the full n-gram draft matches target greedy + tokens. On any mismatch we roll back the entire draft suffix, emit the + target token after ``state.next_main``, and refresh the MTP draft from the + confirmed hidden state. This preserves exact greedy output while avoiding + partial-accept rollback complexity for GDN caches. + """ + import time + + import mlx.core as mx + + if not _ngram_enabled(gen_batch) or not _is_greedy(gen_batch): + return False + if state.next_main is None: + return False + + n_match, draft_min, draft_max, min_count, min_confidence, max_entries = ( + _ngram_config(gen_batch) + ) + history = _ngram_history(gen_batch) + if len(history) < n_match: + return False + + _update_ngram_index(state, history, n_match, min_count, min_confidence, max_entries) + draft_ids = _draft_from_ngram( + state, history, n_match, min_count, min_confidence, draft_max + ) + if len(draft_ids) < draft_min: + return False + + sampler = _resolve_sampler(gen_batch) + procs = _proc_list(gen_batch) + draft_arr = mx.array(draft_ids, dtype=mx.uint32) + inputs = mx.concatenate([state.next_main, draft_arr]) + + prev_buffers: list[Any] = [] + if procs is not None: + prev_buffers.append( + gen_batch._token_context[0].update_and_fetch(state.next_main) + ) + for token_id in draft_ids: + tok = mx.array([token_id], dtype=mx.uint32) + prev_buffers.append(gen_batch._token_context[0].update_and_fetch(tok)) + + t0 = time.perf_counter() + with mx.stream(_get_generation_stream()): + logits, hidden, gdn_states = _call_backbone( + gen_batch.model, + inputs[None, :], + gen_batch.prompt_cache, + n_confirmed=1, + ) + state.stats.backbone_ms += (time.perf_counter() - t0) * 1000 + + t0 = time.perf_counter() + verify_logits = logits[:, : len(draft_ids), :] + bonus_logits = logits[:, len(draft_ids), :] + if procs is not None: + processed = [] + for i in range(len(draft_ids)): + processed.append( + _apply_processors(procs, prev_buffers[i], verify_logits[:, i, :]) + ) + verify_logits = mx.stack(processed, axis=1) + bonus_logits = _apply_processors(procs, prev_buffers[-1], bonus_logits) + + verify_tok = mx.argmax(verify_logits, axis=-1).reshape(-1) + bonus_tok = mx.argmax(bonus_logits, axis=-1).reshape(-1) + mx.eval(verify_tok, bonus_tok) + + verify_ids = [int(t) for t in verify_tok.tolist()] + bonus_id = int(bonus_tok.tolist()[0]) + n_accept = 0 + for expected, actual in zip(draft_ids, verify_ids): + if expected != actual: + break + n_accept += 1 + full_accept = n_accept == len(draft_ids) + state.stats.sample_ms += (time.perf_counter() - t0) * 1000 + state.stats.ngram_cycles += 1 + state.stats.ngram_draft_tokens += len(draft_ids) + + if full_accept: + state.stats.ngram_accepts += 1 + state.stats.ngram_accepted_tokens += len(draft_ids) + _mark_used_ngram_chain( + state, + history, + draft_ids + [bonus_id], + n_match, + max_entries, + ) + t0 = time.perf_counter() + _clear_rollback(gen_batch.prompt_cache) + state.stats.cache_ops_ms += (time.perf_counter() - t0) * 1000 + + hidden_at_last_accepted = hidden[:, len(draft_ids) : len(draft_ids) + 1, :] + _defer_mtp_refresh( + gen_batch, + state, + hidden_at_last_accepted, + _ensure_uint32(bonus_tok), + prev_buffers[-1] if procs is not None else None, + ) + for i, token_id in enumerate(draft_ids): + state.queue.append((token_id, None, "ngram")) + state.queue.append((bonus_id, None, "bonus")) + state.next_main = _ensure_uint32(bonus_tok) + return True + + state.stats.ngram_rejects += 1 + t0 = time.perf_counter() + if not _rollback_after_reject( + gen_batch.model, + gen_batch.prompt_cache, + gdn_states, + accepted=0, + block_size=len(draft_ids) + 1, + ): + if procs is not None: + _trim_token_buffer(gen_batch, len(draft_ids)) + return False + if procs is not None: + _trim_token_buffer(gen_batch, len(draft_ids)) + state.stats.cache_ops_ms += (time.perf_counter() - t0) * 1000 + + emit_id = verify_ids[0] + _mark_used_ngram(state, history, emit_id, n_match, max_entries) + emit_tok = mx.array([emit_id], dtype=mx.uint32) + hidden_at_last_committed = hidden[:, 0:1, :] + _defer_mtp_refresh( + gen_batch, + state, + hidden_at_last_committed, + emit_tok, + prev_buffers[0] if procs is not None else None, + ) + state.queue.append((emit_id, None, "verify")) + state.next_main = emit_tok + return True # --------------------------------------------------------------------------- # Verify cycle: 2-token forward + accept/reject + MTP forward for next draft. # --------------------------------------------------------------------------- + def _run_verify_cycle(gen_batch: Any, state: _MtpState) -> None: """Run one verify cycle. Populates ``state.queue`` with 1 (reject) or 2 (accept) tokens for upcoming emit calls. Updates ``state.next_main`` and @@ -793,8 +1356,7 @@ def _run_verify_cycle(gen_batch: Any, state: _MtpState) -> None: accept = verify_id == draft_id else: log_accept = ( - verify_accept_lp[0, draft_id].item() - - draft_accept_lp[draft_id].item() + verify_accept_lp[0, draft_id].item() - draft_accept_lp[draft_id].item() ) accept = log_accept >= 0 or random.random() < math.exp(log_accept) state.stats.sample_ms += (time.perf_counter() - t0) * 1000 @@ -834,8 +1396,11 @@ def _run_verify_cycle(gen_batch: Any, state: _MtpState) -> None: # accepted=0 means only the confirmed token (verify position) is kept; # block_size=2 covers both the confirmed and the rejected draft. if not _rollback_after_reject( - gen_batch.model, gen_batch.prompt_cache, gdn_states, - accepted=0, block_size=2, + gen_batch.model, + gen_batch.prompt_cache, + gdn_states, + accepted=0, + block_size=2, ): if procs is not None: _trim_token_buffer(gen_batch, 1) @@ -876,6 +1441,7 @@ def _run_verify_cycle(gen_batch: Any, state: _MtpState) -> None: # Helpers used by the verify cycle. # --------------------------------------------------------------------------- + def _step_mtp( gen_batch: Any, hidden_at_position: Any, @@ -905,9 +1471,7 @@ def _step_mtp( ) mtp_logits_2d = mtp_logits[:, -1, :] if procs is not None and prev_buf is not None: - prev_with_next = mx.concatenate( - [prev_buf, _ensure_uint32(next_main_tok)] - ) + prev_with_next = mx.concatenate([prev_buf, _ensure_uint32(next_main_tok)]) mtp_logits_2d = _apply_processors(procs, prev_with_next, mtp_logits_2d) new_lp = _logprobs(mtp_logits_2d) new_tok = sampler(new_lp) @@ -953,6 +1517,7 @@ def _residual_sample(verify_lp_2d: Any, draft_lp_1d: Any) -> Tuple[int, Any]: # Response builder — mirrors GenerationBatch.next()'s per-sequence epilogue. # --------------------------------------------------------------------------- + def _emit_response( gen_batch: Any, token_id: int, diff --git a/omlx/scheduler.py b/omlx/scheduler.py index 826c55c35..92fd7924f 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -301,6 +301,7 @@ def _patched_generation_batch_step(self): _ckvcache_methods_skipped: list[str] = [] if not hasattr(_CKVCache, "merge"): + @classmethod def _ckvcache_merge_passthrough(cls, caches): if len(caches) == 1: @@ -315,6 +316,7 @@ def _ckvcache_merge_passthrough(cls, caches): _ckvcache_methods_skipped.append("merge") if not hasattr(_CKVCache, "filter"): + def _ckvcache_filter_passthrough(self, batch_indices): try: n = len(batch_indices) @@ -339,6 +341,7 @@ def _ckvcache_filter_passthrough(self, batch_indices): _ckvcache_methods_skipped.append("filter") if not hasattr(_CKVCache, "extract"): + def _ckvcache_extract_passthrough(self, idx): return self @@ -347,6 +350,7 @@ def _ckvcache_extract_passthrough(self, idx): _ckvcache_methods_skipped.append("extract") if not hasattr(_CKVCache, "size"): + def _ckvcache_size(self): return max(0, self.offset - self.start_position) @@ -355,6 +359,7 @@ def _ckvcache_size(self): _ckvcache_methods_skipped.append("size") if not hasattr(_CKVCache, "extend"): + def _ckvcache_extend_passthrough(self, other): if other is None or other.empty(): return @@ -378,8 +383,7 @@ def _ckvcache_extend_passthrough(self, other): # Surface which ones so a regression in Llama-4 batching is visible # to operators without diffing the patch against installed mlx_lm. logger.info( - "ChunkedKVCache patch: methods already present upstream, " - "skipped: %s", + "ChunkedKVCache patch: methods already present upstream, " "skipped: %s", ", ".join(_ckvcache_methods_skipped), ) except ImportError: @@ -1885,7 +1889,11 @@ def _begin_prefill( if hasattr(self.model, "clear_vlm_position_state"): self.model.clear_vlm_position_state() - prompt_cache = existing_cache if existing_cache is not None else make_prompt_cache(self.model) + prompt_cache = ( + existing_cache + if existing_cache is not None + else make_prompt_cache(self.model) + ) block_size = self.config.paged_cache_block_size boundary_enabled = ( @@ -1967,7 +1975,9 @@ def _step_prefill_chunk(self, state: _PrefillState) -> bool: and total_tokens % state.block_size == 0 and state.emitted_boundaries.get(rid, -1) < total_tokens ): - self._emit_prefill_boundary_snapshot(state.request, state.cache, total_tokens) + self._emit_prefill_boundary_snapshot( + state.request, state.cache, total_tokens + ) state.emitted_boundaries[rid] = total_tokens # Progress callback so the admin UI prefilling list advances during @@ -1978,9 +1988,11 @@ def _step_prefill_chunk(self, state: _PrefillState) -> bool: state.request.request_id, state.tokens_processed, state.total_length - 1, - os.path.basename(self.config.model_name.rstrip("/")) - if self.config.model_name - else "", + ( + os.path.basename(self.config.model_name.rstrip("/")) + if self.config.model_name + else "" + ), ) # Memory monitoring — use max(active, phys_footprint) so MLX cache @@ -2025,7 +2037,9 @@ def _emit_final_boundary_if_needed(self, state: _PrefillState) -> None: and total_tokens % state.block_size == 0 and state.emitted_boundaries.get(rid, -1) < total_tokens ): - self._emit_prefill_boundary_snapshot(state.request, state.cache, total_tokens) + self._emit_prefill_boundary_snapshot( + state.request, state.cache, total_tokens + ) def _insert_prefilled_request( self, @@ -2057,6 +2071,12 @@ def _insert_prefilled_request( if uids: uid = uids[0] + if getattr(self.model, "_omlx_ngram_spec_enabled", False): + prompt_map = getattr(self.model, "_omlx_ngram_prompt_tokens", None) + if prompt_map is None: + prompt_map = {} + self.model._omlx_ngram_prompt_tokens = prompt_map + prompt_map[uid] = list(request.prompt_token_ids or []) self.request_id_to_uid[request.request_id] = uid self.uid_to_request_id[uid] = request.request_id now = time.monotonic() @@ -2077,8 +2097,11 @@ def _insert_prefilled_request( logger.debug( "Scheduled chunked-prefill request %s (uid=%d) " "with %d tokens (%d total)%s", - request.request_id, uid, - len(state.last_token), request.num_prompt_tokens, cache_info, + request.request_id, + uid, + len(state.last_token), + request.num_prompt_tokens, + cache_info, ) def _advance_chunked_prefills( @@ -2154,7 +2177,8 @@ def _advance_chunked_prefills( # Unlikely, but if BG creation fails put request back. logger.error( "BatchGenerator unavailable at chunked-prefill completion " - "for %s; requeueing.", rid + "for %s; requeueing.", + rid, ) still_prefilling.append(request) self._prefill_states[rid] = state @@ -3812,7 +3836,8 @@ def _try_specprefill_scoring(self, request: Request) -> None: spec_extra = { "prompt_tokens": request.num_prompt_tokens, "system_tokens": request.specprefill_system_end, - "conversation_tokens": request.num_prompt_tokens - request.specprefill_system_end, + "conversation_tokens": request.num_prompt_tokens + - request.specprefill_system_end, "cached_tokens": request.cached_tokens, } @@ -4130,7 +4155,12 @@ def has_requests(self) -> bool: Without this, an idle server would never reach the target step and stale buffers would accumulate indefinitely. """ - return bool(self.waiting or self.prefilling or self.running or self._deferred_clear_at is not None) + return bool( + self.waiting + or self.prefilling + or self.running + or self._deferred_clear_at is not None + ) def fail_all_requests(self) -> list[str]: """Remove all running and waiting requests after unrecoverable error. @@ -4480,13 +4510,14 @@ def _check_specprefill_abort(processed: int) -> None: spec_sparse_extra = { "prompt_tokens": request.num_prompt_tokens, "system_tokens": request.specprefill_system_end, - "conversation_tokens": request.num_prompt_tokens - request.specprefill_system_end, + "conversation_tokens": request.num_prompt_tokens + - request.specprefill_system_end, "cached_tokens": request.cached_tokens, "scored_tokens": m_pre, "selected_tokens": n_eff, - "keep_percent": round(n_eff / m_pre * 100) - if m_pre > 0 - else 0, + "keep_percent": ( + round(n_eff / m_pre * 100) if m_pre > 0 else 0 + ), } while sys_arr.size > step: _check_specprefill_abort(sys_processed) @@ -4576,12 +4607,15 @@ def _sparse_progress(processed: int, total: int) -> None: extra={ "scored_tokens": M, "selected_tokens": int(selected.shape[0]), - "keep_percent": round(int(selected.shape[0]) / M * 100) - if M > 0 - else 0, + "keep_percent": ( + round(int(selected.shape[0]) / M * 100) + if M > 0 + else 0 + ), "prompt_tokens": request.num_prompt_tokens, "system_tokens": request.specprefill_system_end, - "conversation_tokens": request.num_prompt_tokens - request.specprefill_system_end, + "conversation_tokens": request.num_prompt_tokens + - request.specprefill_system_end, "cached_tokens": request.cached_tokens, }, ) @@ -4660,7 +4694,9 @@ def _sparse_progress(processed: int, total: int) -> None: ): sm = self._build_state_machine(request) per_row_lps = list(logits_processors) if logits_processors else [] - state = self._begin_prefill(request, tokens_to_process, cache_to_use) + state = self._begin_prefill( + request, tokens_to_process, cache_to_use + ) state.sampler = sampler state.sm = sm state.per_row_lps = per_row_lps @@ -4784,6 +4820,12 @@ def _sparse_progress(processed: int, total: int) -> None: if uids: uid = uids[0] + if getattr(self.model, "_omlx_ngram_spec_enabled", False): + prompt_map = getattr(self.model, "_omlx_ngram_prompt_tokens", None) + if prompt_map is None: + prompt_map = {} + self.model._omlx_ngram_prompt_tokens = prompt_map + prompt_map[uid] = list(request.prompt_token_ids or []) self.request_id_to_uid[request.request_id] = uid self.uid_to_request_id[uid] = request.request_id now = time.monotonic() @@ -5205,6 +5247,9 @@ def _cleanup_finished(self, finished_ids: set[str]) -> None: # fall back to immediate remove for back-compat behavior. if request_id in self.request_id_to_uid: uid = self.request_id_to_uid[request_id] + prompt_map = getattr(self.model, "_omlx_ngram_prompt_tokens", None) + if prompt_map is not None: + prompt_map.pop(uid, None) if store_future is not None: self._pending_async_removes.append((uid, request_id, store_future)) else: @@ -5457,7 +5502,9 @@ def step(self) -> SchedulerOutput: # Run generation step if we have running requests. # Use next_generated() which returns only GenerationBatch.Response # objects (prefill is handled externally before insert). - if (self.batch_generator is not None or self._vlm_mtp_active) and self.running: + if ( + self.batch_generator is not None or self._vlm_mtp_active + ) and self.running: if self.batch_generator is not None: responses = list(self.batch_generator.next_generated()) else: @@ -5492,10 +5539,9 @@ def step(self) -> SchedulerOutput: # there is no race window. Decode-only path — # next_generated() returns nothing during prefill, so # we never disrupt prefill activation buffers. - self._tokens_since_clear_cache = ( - getattr(self, "_tokens_since_clear_cache", 0) - + len(responses) - ) + self._tokens_since_clear_cache = getattr( + self, "_tokens_since_clear_cache", 0 + ) + len(responses) if self._tokens_since_clear_cache >= 1024: _sync_and_clear_cache() self._tokens_since_clear_cache = 0 diff --git a/tests/test_mlx_lm_mtp_patch.py b/tests/test_mlx_lm_mtp_patch.py index 996d1274f..231bf6c72 100644 --- a/tests/test_mlx_lm_mtp_patch.py +++ b/tests/test_mlx_lm_mtp_patch.py @@ -25,6 +25,7 @@ # Patch orchestrator + sub-modules # --------------------------------------------------------------------------- + class TestApplyOrchestrator: def test_apply_idempotent(self): from omlx.patches.mlx_lm_mtp import apply_mlx_lm_mtp_patch @@ -173,6 +174,7 @@ def _apply(self): if not qwen35_model.apply(): pytest.skip("qwen35_model patch refused to apply") from omlx.patches.mlx_lm_mtp.qwen35_model import _patch_qwen3_5_moe + _patch_qwen3_5_moe() @pytest.fixture() @@ -278,7 +280,9 @@ def test_skip_when_base_patch_not_applied(self, monkeypatch): from omlx.patches.mlx_lm_mtp import deepseek_v4_model # Simulate the base patch not having run by removing the module. - monkeypatch.setitem(__import__("sys").modules, "mlx_lm.models.deepseek_v4", None) + monkeypatch.setitem( + __import__("sys").modules, "mlx_lm.models.deepseek_v4", None + ) # Reset the module-level _PATCHED flag so apply() actually runs the # gating check rather than short-circuiting on idempotency. monkeypatch.setattr(deepseek_v4_model, "_PATCHED", False) @@ -297,7 +301,11 @@ def test_apply_with_base_patch_registers_mtp_block(self): from omlx.patches.deepseek_v4 import apply_deepseek_v4_patch except ImportError: pytest.skip("omlx.patches.deepseek_v4 not importable") - if not apply_deepseek_v4_patch(): + try: + applied_base = apply_deepseek_v4_patch() + except ImportError as exc: + pytest.skip(f"DeepSeek-V4 base patch dependency unavailable: {exc}") + if not applied_base: pytest.skip("DeepSeek-V4 base patch refused to apply in this env") from omlx.patches.mlx_lm_mtp import deepseek_v4_model @@ -365,24 +373,19 @@ def __init__(self, model, uids): # (e.g. VLM runtime patches attach unconditionally so weight # load matches, while inference-time MTP stays disabled). mlx_lm_mtp.set_mtp_active(False) - assert ( - _is_mtp_eligible(_GenBatch(_MtpModel(), uids=[1])) is False - ) + assert _is_mtp_eligible(_GenBatch(_MtpModel(), uids=[1])) is False mlx_lm_mtp.set_mtp_active(True) # Non-MTP model never triggers the MTP path. assert _is_mtp_eligible(_GenBatch(_NonMtpModel(), uids=[1])) is False # Has mtp_forward but no attached head → still off. assert ( - _is_mtp_eligible(_GenBatch(_MtpModelWithoutHead(), uids=[1])) - is False + _is_mtp_eligible(_GenBatch(_MtpModelWithoutHead(), uids=[1])) is False ) # Has both method and head + batch=1 + flag on → triggers the path. assert _is_mtp_eligible(_GenBatch(_MtpModel(), uids=[1])) is True # MTP model with batch=2 falls back to standard step. - assert ( - _is_mtp_eligible(_GenBatch(_MtpModel(), uids=[1, 2])) is False - ) + assert _is_mtp_eligible(_GenBatch(_MtpModel(), uids=[1, 2])) is False # Empty batch never triggers. assert _is_mtp_eligible(_GenBatch(_MtpModel(), uids=[])) is False finally: @@ -393,6 +396,7 @@ def __init__(self, model, uids): # ModelSettings — mtp_enabled field + mutual exclusion # --------------------------------------------------------------------------- + class TestModelSettingsMtp: def test_default_mtp_disabled(self): s = ModelSettings() @@ -403,6 +407,32 @@ def test_mtp_enabled_roundtrip(self): restored = ModelSettings.from_dict(original.to_dict()) assert restored.mtp_enabled is True + def test_ngram_settings_roundtrip(self): + original = ModelSettings( + mtp_enabled=True, + ngram_spec_enabled=True, + ngram_spec_n_match=4, + ngram_spec_draft_min=2, + ngram_spec_draft_max=6, + ngram_spec_min_count=3, + ngram_spec_min_confidence=0.75, + ngram_spec_max_entries=128, + ngram_spec_mtp_fallback=False, + ngram_spec_mtp_adaptive=False, + ngram_spec_mtp_min_cycles=4, + ngram_spec_mtp_min_accept_rate=0.25, + ) + restored = ModelSettings.from_dict(original.to_dict()) + assert restored.ngram_spec_enabled is True + assert restored.ngram_spec_n_match == 4 + assert restored.ngram_spec_min_count == 3 + assert restored.ngram_spec_min_confidence == 0.75 + assert restored.ngram_spec_max_entries == 128 + assert restored.ngram_spec_mtp_fallback is False + assert restored.ngram_spec_mtp_adaptive is False + assert restored.ngram_spec_mtp_min_cycles == 4 + assert restored.ngram_spec_mtp_min_accept_rate == 0.25 + def test_legacy_settings_dict_defaults_mtp_off(self): s = ModelSettings.from_dict({"display_name": "qwen3.6"}) assert s.mtp_enabled is False @@ -422,11 +452,95 @@ def test_mtp_with_specprefill_allowed(self): assert s.mtp_enabled is True assert s.specprefill_enabled is True + def test_ngram_requires_repeated_key_before_indexing(self): + from omlx.patches.mlx_lm_mtp.batch_generator import ( + _MtpState, + _update_ngram_index, + ) + + state = _MtpState() + history = [1, 2, 3, 1, 2, 3, 1, 2, 4] + + _update_ngram_index( + state, + history, + n_match=2, + min_count=2, + min_confidence=0.5, + max_entries=8, + ) + + assert state.ngram_index[(1, 2)] == 3 + assert (2, 4) not in state.ngram_index + + def test_ngram_index_respects_max_entries(self): + from omlx.patches.mlx_lm_mtp.batch_generator import ( + _MtpState, + _update_ngram_index, + ) + + state = _MtpState() + history = [1, 9, 1, 9, 2, 9, 2, 9, 3, 9, 3, 9] + + _update_ngram_index( + state, + history, + n_match=1, + min_count=2, + min_confidence=0.5, + max_entries=2, + ) + + assert len(state.ngram_index) == 2 + + def test_ngram_index_requires_confident_follower(self): + from omlx.patches.mlx_lm_mtp.batch_generator import ( + _MtpState, + _update_ngram_index, + ) + + state = _MtpState() + history = [1, 2, 3, 1, 2, 4, 1, 2, 3, 1, 2, 4] + + _update_ngram_index( + state, + history, + n_match=2, + min_count=2, + min_confidence=0.75, + max_entries=8, + ) + + assert (1, 2) not in state.ngram_index + + def test_ngram_draft_prioritizes_used_over_frequency(self): + from omlx.patches.mlx_lm_mtp.batch_generator import ( + _MtpState, + _draft_from_ngram, + ) + + state = _MtpState() + state.ngram_index[(1, 2)] = 3 + state.ngram_counts[(1, 2)] = {3: 10} + state.ngram_used[(1, 2)] = 4 + + draft = _draft_from_ngram( + state, + history=[1, 2], + n_match=2, + min_count=3, + min_confidence=0.8, + draft_max=1, + ) + + assert draft == [4] + # --------------------------------------------------------------------------- # utils.model_loading — compatibility helpers + dispatch # --------------------------------------------------------------------------- + class TestMtpCompatibilityHelpers: def test_has_mtp_heads_top_level_field(self): assert _has_mtp_heads({"mtp_num_hidden_layers": 1}) is True @@ -435,9 +549,7 @@ def test_has_mtp_heads_nextn_field(self): assert _has_mtp_heads({"num_nextn_predict_layers": 2}) is True def test_has_mtp_heads_text_config_field(self): - assert ( - _has_mtp_heads({"text_config": {"mtp_num_hidden_layers": 1}}) is True - ) + assert _has_mtp_heads({"text_config": {"mtp_num_hidden_layers": 1}}) is True def test_has_mtp_heads_zero_is_false(self): assert _has_mtp_heads({"mtp_num_hidden_layers": 0}) is False @@ -501,10 +613,13 @@ def test_dispatch_skips_when_incompatible_model(self, tmp_path, caplog): str(tmp_path), model_settings=ModelSettings(mtp_enabled=True) ) # The skip path should log a warning so the user sees why MTP was inactive. - assert any( - "MTP path will be inactive" in record.getMessage() - for record in caplog.records - ) or True # logger.warning may be filtered by pytest logging level + assert ( + any( + "MTP path will be inactive" in record.getMessage() + for record in caplog.records + ) + or True + ) # logger.warning may be filtered by pytest logging level def test_dispatch_handles_missing_config(self, tmp_path): # No config.json at all — function must not raise.