Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 216 additions & 0 deletions docs/ngram-mtp-speculative-decoding-study.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# N-gram + MTP Speculative Decoding Notes

## Goal

Test whether n-gram speculation can make long roleplay generation faster in oMLX, and whether it should be combined with MTP.

## Short Answer

Yes, it helps on long repeated conversations.

Best current routing:

```text
1. Try short used-priority n-gram draft.
2. If n-gram misses, try MTP fallback.
3. If MTP fallback is not accepting enough, disable it for the rest of the request.
4. Fall back to plain target greedy when needed.
```

## Current Recommended Settings

```text
ngram_spec_enabled = true
ngram_spec_n_match = 4
ngram_spec_draft_min = 1
ngram_spec_draft_max = 2
ngram_spec_min_count = 3
ngram_spec_min_confidence = 0.8
ngram_spec_max_entries = 2048
ngram_spec_mtp_fallback = true
ngram_spec_mtp_adaptive = true
ngram_spec_mtp_min_cycles = 8
ngram_spec_mtp_min_accept_rate = 0.5
```

## Key Ideas

### 1. N-gram should be short

Long n-gram drafts caused problems.

On the 40-turn roleplay test:

| `draft_max` | Correct | Speed |
|---:|---:|---:|
| 1 | yes | 61.70 tok/s |
| 2 | yes | 72.50 tok/s |
| 4 | no | diverged |
| 8 | no | diverged |

So the default should stay small:

```text
ngram_spec_draft_max = 2
```

### 2. Used n-grams should win over frequent n-grams

Prompt frequency alone is not enough.

Example:

```text
Key: Archive keeper: The
Frequent prompt continuation: eastern aisle...
Current live continuation: western stair...
```

If the model already used `western` in this generation, that should be prioritized over the more frequent prompt branch.

So the implementation uses:

```text
used n-gram table first
frequency table second
```

### 3. MTP helps, but not everywhere

MTP-only works on many prompts.

Example 40-turn run:

```text
Plain greedy: 46.67 tok/s
MTP-only: 53.09 tok/s
```

But n-gram helps more on repeated conversations:

```text
N-gram target fallback: 66.15 tok/s
```

MTP is best used as a fallback after n-gram misses, not as the main strategy for repeated roleplay text.

### 4. Adaptive MTP fallback is best

MTP fallback can help, but if it starts rejecting too much, it becomes overhead.

So we track MTP fallback accept rate per request.

If accept rate is too low after enough cycles, MTP fallback is disabled for the rest of the request.

## Benchmark Results

### 40-turn roleplay benchmark

Generation length: 320 tokens.

| Path | Correct | wall tok/s | decode tok/s |
|---|---:|---:|---:|
| Plain greedy | yes | 48.72 | 62.67 |
| N-gram + target fallback | yes | 67.44 | 101.33 |
| N-gram + MTP fallback | yes | 68.85 | 103.87 |
| N-gram + adaptive MTP fallback | yes | 69.61 | 104.35 |

Best result:

```text
N-gram + adaptive MTP fallback
69.61 tok/s wall throughput
104.35 tok/s decode throughput
```

### Prompt-shape matrix

| Case | Best path | Result |
|---|---|---|
| Low-repeat prose | MTP-only | small gain |
| Short repeated oath | N-gram | small gain |
| 40-turn conversation | N-gram + adaptive MTP | large gain |
| Branch-heavy repeated prompt | unsafe | speculative paths diverged |

## Example N-gram Suggestions

From the 40-turn roleplay prompt:

| Key | Suggested draft |
|---|---|
| `remember the river,` | ` the tower` |
| `the river, the` | ` tower,` |
| `river, the tower` | `, and` |
| `the tower, and` | ` the name` |
| `and the name beneath` | ` the glass` |
| `Mira: The` | ` river mark` |
| `The river mark is` | ` still cold` |

Replay stats:

```text
315 n-gram suggestion events
313 full matches
629 drafted tokens
627 accepted-prefix tokens
```

## N-gram vs MTP Overlap

Diagnostic test on the 40-turn prompt:

```text
overlap events: 110
first-token agree: 52
first-token disagree: 58
agreement rate: 47.3%
```

Meaning:

- N-gram and MTP are not redundant.
- N-gram is better at exact repeated text.
- MTP is better as a local model-based fallback.

Example agreement:

| Key | N-gram | MTP |
|---|---|---|
| `tower, and the` | ` name beneath` | ` name` |
| `The river mark is` | ` still cold` | ` still` |

Example disagreement:

| Key | N-gram | MTP |
|---|---|---|
| `Mira: The` | ` river mark` | ` name` |
| `Archive keeper:` | ` Name the` | ` The` |

## Remaining Risks

The branch-heavy repeated prompt still diverged under speculative modes.

So this is not yet a universal production-safe optimization for every prompt shape.

Safe target use case:

```text
long repeated conversation / roleplay structure
greedy decoding
short n-gram drafts
adaptive MTP fallback
```

## Conclusion

N-gram speculation is useful for long roleplay conversations because the text has repeated structure.

MTP also works, but it is better as an adaptive fallback.

The best current policy is:

```text
short used-priority n-gram first
adaptive MTP fallback second
plain target greedy fallback when MTP stops helping
```
51 changes: 46 additions & 5 deletions omlx/engine/batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,42 @@ def _load_model_sync():

self._model = apply_post_load_transforms(self._model, self._model_settings)

if self._model_settings is not None:
self._model._omlx_ngram_spec_enabled = bool(
getattr(self._model_settings, "ngram_spec_enabled", False)
)
self._model._omlx_ngram_spec_n_match = int(
getattr(self._model_settings, "ngram_spec_n_match", 4) or 4
)
self._model._omlx_ngram_spec_draft_min = int(
getattr(self._model_settings, "ngram_spec_draft_min", 1) or 1
)
self._model._omlx_ngram_spec_draft_max = int(
getattr(self._model_settings, "ngram_spec_draft_max", 2) or 2
)
self._model._omlx_ngram_spec_min_count = int(
getattr(self._model_settings, "ngram_spec_min_count", 3) or 3
)
self._model._omlx_ngram_spec_min_confidence = float(
getattr(self._model_settings, "ngram_spec_min_confidence", 0.8) or 0.8
)
self._model._omlx_ngram_spec_max_entries = int(
getattr(self._model_settings, "ngram_spec_max_entries", 2048) or 2048
)
self._model._omlx_ngram_spec_mtp_fallback = bool(
getattr(self._model_settings, "ngram_spec_mtp_fallback", True)
)
self._model._omlx_ngram_spec_mtp_adaptive = bool(
getattr(self._model_settings, "ngram_spec_mtp_adaptive", True)
)
self._model._omlx_ngram_spec_mtp_min_cycles = int(
getattr(self._model_settings, "ngram_spec_mtp_min_cycles", 8) or 8
)
self._model._omlx_ngram_spec_mtp_min_accept_rate = float(
getattr(self._model_settings, "ngram_spec_mtp_min_accept_rate", 0.5)
or 0.5
)

# TurboQuant KV cache: patch attention and set kv_bits on scheduler
if self._model_settings is not None:
tq_enabled = getattr(self._model_settings, "turboquant_kv_enabled", False)
Expand Down Expand Up @@ -440,7 +476,8 @@ def count_chat_tokens(
messages = self._preprocess_messages(messages)
template_tools = convert_tools_for_template(tools) if tools else None
prompt = self._apply_chat_template(
messages, template_tools,
messages,
template_tools,
chat_template_kwargs=chat_template_kwargs,
is_partial=is_partial,
)
Expand Down Expand Up @@ -675,8 +712,10 @@ async def chat(
ct_kwargs = kwargs.pop("chat_template_kwargs", None)
partial = kwargs.pop("is_partial", None)
prompt = self._apply_chat_template(
messages, template_tools,
chat_template_kwargs=ct_kwargs, is_partial=partial,
messages,
template_tools,
chat_template_kwargs=ct_kwargs,
is_partial=partial,
)

return await self.generate(
Expand Down Expand Up @@ -735,8 +774,10 @@ async def stream_chat(
ct_kwargs = kwargs.pop("chat_template_kwargs", None)
partial = kwargs.pop("is_partial", None)
prompt = self._apply_chat_template(
messages, template_tools,
chat_template_kwargs=ct_kwargs, is_partial=partial,
messages,
template_tools,
chat_template_kwargs=ct_kwargs,
is_partial=partial,
)

# SpecPrefill: compute system prompt token count for protection.
Expand Down
37 changes: 25 additions & 12 deletions omlx/model_profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@
"vlm_mtp_enabled",
"vlm_mtp_draft_model",
"vlm_mtp_draft_block_size",
"ngram_spec_enabled",
"ngram_spec_n_match",
"ngram_spec_draft_min",
"ngram_spec_draft_max",
"ngram_spec_min_count",
"ngram_spec_min_confidence",
"ngram_spec_max_entries",
"ngram_spec_mtp_fallback",
"ngram_spec_mtp_adaptive",
"ngram_spec_mtp_min_cycles",
"ngram_spec_mtp_min_accept_rate",
"specprefill_enabled",
"specprefill_draft_model",
"specprefill_keep_pct",
Expand All @@ -69,18 +80,20 @@
)

# Excluded — never stored in a profile or template.
EXCLUDED_FROM_PROFILES = frozenset({
"is_pinned",
"is_default",
"display_name",
"description",
"model_alias",
"model_type_override",
"active_profile_name",
"ttl_seconds",
# Security flag must be explicit per model — never propagated via profiles.
"trust_remote_code",
})
EXCLUDED_FROM_PROFILES = frozenset(
{
"is_pinned",
"is_default",
"display_name",
"description",
"model_alias",
"model_type_override",
"active_profile_name",
"ttl_seconds",
# Security flag must be explicit per model — never propagated via profiles.
"trust_remote_code",
}
)


def filter_universal_fields(data: dict[str, Any]) -> dict[str, Any]:
Expand Down
Loading