Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions mtplx/profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@
"MTPLX_LONG_CONTEXT_MTP_DEPTH_POLICY": "auto",
"MTPLX_LONG_CONTEXT_MTP_DEPTH_THRESHOLD": "98304",
"MTPLX_LONG_CONTEXT_MTP_DEPTH": "2",
# Aggressive context-aware ladder. The legacy threshold/cap above stays
# in place as a fallback when the ladder is explicitly disabled (env var
# set to ""). See `resolve_long_context_mtp_depth` in profiles.py and
# the rationale block above DEFAULT_MTP_LONG_CONTEXT_LADDER.
"MTPLX_MTP_LONG_CONTEXT_LADDER": "16384:2,24576:1,30720:0",
"MTPLX_MTP_HISTORY_POLICY": "auto",
"MTPLX_MTP_HISTORY_LAST_WINDOW": "8192",
"MTPLX_MTP_HISTORY_LAST_WINDOW_THRESHOLD": "16384",
Expand Down Expand Up @@ -126,6 +131,80 @@ def _env_int(
return default


# --- MTP long-context depth ladder -------------------------------------------
#
# Rationale: MTP draft acceptance on Qwen3-family models collapses as context
# grows. Independent vLLM tracking issues confirm the same trend (#35387 shows
# a 76% latency regression at 256K context; #40756 reports crashes at 26K;
# #36872 documents acceptance dropping 61% -> 0.9% -> 0% across consecutive
# turns). Our own MTPLX tracking issue #49 reproduces the curve at depth=3
# (47% -> 33% -> 27%). The vLLM-recommended setting for Qwen3.6-35B is
# `num_speculative_tokens=2`, not 3.
#
# Cheapest mitigation: lower the speculative depth as the prompt grows. This
# ladder caps the requested depth based on prompt token count, with a final
# rung that disables MTP entirely (depth=0) once the prompt reaches the
# largest threshold. The existing single-step `auto` policy and the new ladder
# coexist; the ladder takes precedence when configured.
#
# Spec format: comma-separated `threshold:capped_depth` pairs. A rung applies
# when prompt_tokens >= threshold; the highest matching rung wins.
#
# Default ladder ("16384:2,24576:1,30720:0"):
#   < 16384        -> requested depth (default 3)
#   16384 .. 24575 -> min(requested, 2)
#   24576 .. 30719 -> min(requested, 1)
#   >= 30720       -> 0 (MTP off)
DEFAULT_MTP_LONG_CONTEXT_LADDER = "16384:2,24576:1,30720:0"


def _parse_ladder(
raw: str | None,
) -> tuple[tuple[int, int], ...]:
"""Parse a comma-separated `threshold:capped_depth` ladder spec.

Returns the ladder sorted ascending by threshold. Returns an empty tuple
on an explicit empty string (caller treats this as "ladder disabled").
Malformed entries are skipped silently so a bad env var never bricks
the server. Negative thresholds clamp to 0; capped_depth clamps to >= 0
(0 means "MTP off at this rung").
"""

if raw is None:
return ()
text = raw.strip()
if not text:
return ()
pairs: list[tuple[int, int]] = []
for chunk in text.split(","):
chunk = chunk.strip()
if not chunk:
continue
if ":" not in chunk:
continue
thr_str, cap_str = chunk.split(":", 1)
try:
threshold = max(0, int(thr_str.strip()))
cap = max(0, int(cap_str.strip()))
except (TypeError, ValueError):
continue
pairs.append((threshold, cap))
pairs.sort(key=lambda kv: kv[0])
return tuple(pairs)


def _ladder_step_for_prompt(
ladder: tuple[tuple[int, int], ...],
prompt_tokens: int,
) -> tuple[int, int] | None:
"""Return the highest-threshold rung whose threshold <= prompt_tokens."""

selected: tuple[int, int] | None = None
for threshold, cap in ladder:
if int(prompt_tokens) >= int(threshold):
selected = (threshold, cap)
else:
break
return selected


def resolve_long_context_mtp_depth(
*,
prompt_tokens: int,
Expand All @@ -138,6 +217,10 @@ def resolve_long_context_mtp_depth(
Depth 3 is still the default because it wins at short and mid context. On the
M5 Max 128k path, depth 2 recovered decode while preserving exact speculative
sampling. This helper keeps that product policy explicit and observable.

If MTPLX_MTP_LONG_CONTEXT_LADDER is set (or defaults are used and policy
is "auto"), a multi-step ladder takes precedence over the single-threshold
cap. See DEFAULT_MTP_LONG_CONTEXT_LADDER for the default schedule.
"""

source = os.environ if env is None else env
Expand All @@ -157,6 +240,16 @@ def resolve_long_context_mtp_depth(
1,
_env_int(source, "MTPLX_LONG_CONTEXT_MTP_DEPTH", default=2),
)
ladder_raw = source.get("MTPLX_MTP_LONG_CONTEXT_LADDER")
# If env var is unset, fall back to the default ladder. If env var is
# set to an empty string, treat as "ladder explicitly disabled" and
# only use the legacy single-step gate.
if ladder_raw is None:
ladder = _parse_ladder(DEFAULT_MTP_LONG_CONTEXT_LADDER)
ladder_source = "default"
else:
ladder = _parse_ladder(ladder_raw)
ladder_source = "env" if ladder else "env_empty"
details: dict[str, object] = {
"policy": policy,
"prompt_tokens": int(prompt_tokens),
Expand All @@ -166,6 +259,8 @@ def resolve_long_context_mtp_depth(
"min_depth": int(floor),
"active": False,
"reason": "disabled",
"ladder": [list(rung) for rung in ladder],
"ladder_source": ladder_source,
}
if policy in {"", "0", "off", "false", "none"}:
details["effective_depth"] = int(requested)
Expand All @@ -174,6 +269,34 @@ def resolve_long_context_mtp_depth(
details["reason"] = "unknown_policy"
details["effective_depth"] = int(requested)
return requested, details

# Ladder takes precedence over the single-step gate when configured.
rung = _ladder_step_for_prompt(ladder, int(prompt_tokens)) if ladder else None
if rung is not None:
rung_threshold, rung_cap = rung
# rung_cap of 0 means "MTP off at this rung". The downstream generator
# path requires depth >= 1, so we clamp to max(rung_cap, floor) when
# rung_cap > 0, and only return 0 when the caller can handle that
# (server.openai applies the ladder before dispatch and switches to AR).
if rung_cap <= 0:
effective = 0
else:
effective = min(requested, max(rung_cap, floor))
details["ladder_threshold"] = int(rung_threshold)
details["ladder_cap_depth"] = int(rung_cap)
details["effective_depth"] = int(effective)
if effective < requested:
details["active"] = True
details["reason"] = (
f"ladder_step_{rung_threshold}"
if rung_cap > 0
else f"ladder_step_{rung_threshold}_mtp_off"
)
else:
details["reason"] = "ladder_within_cap"
return effective, details

# Fall back to the legacy single-step gate.
if int(prompt_tokens) < threshold:
details["reason"] = "below_threshold"
details["effective_depth"] = int(requested)
Expand Down
92 changes: 90 additions & 2 deletions mtplx/server/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
apply_profile_env,
get_profile,
profile_env_status,
resolve_long_context_mtp_depth,
)
from mtplx.draft_lm_head import _install_draft_lm_head
from mtplx.server_urls import bind_label, is_wildcard_bind, local_url_for_bind
Expand Down Expand Up @@ -2882,6 +2883,52 @@ def _request_depth_for_generation(
return depth


def _apply_long_context_ladder(
*,
prompt_tokens: int,
requested_depth: int,
generation_mode: str,
) -> tuple[int, str, dict[str, Any]]:
"""Apply the context-length-based MTP depth ladder at request time.

Returns ``(effective_depth, effective_generation_mode, details)``. When the
ladder's selected rung specifies a capped depth of 0 (MTP off), the mode
is switched to ``"ar"`` and the returned depth is 0 so the downstream AR
path can run safely. The details dict is shaped the same as the one
produced by :func:`resolve_long_context_mtp_depth` for the in-generation
gate, so it can be surfaced as ``long_context_mtp_depth_policy`` in the
request metrics envelope.

This is the FIRST layer of the gate: the request-time ladder. The
in-generation single-step gate inside ``generate_mtpk`` (also using
``resolve_long_context_mtp_depth``) is the SECOND layer and still runs.
Both layers are safe to stack because the second layer can only further
reduce depth, not increase it.

See ``mtplx.profiles.DEFAULT_MTP_LONG_CONTEXT_LADDER`` for the default
schedule and the rationale (vLLM #35387 / #36872 / #40756, MTPLX #49).
"""

if generation_mode != "mtp":
return int(requested_depth), generation_mode, {}
effective, details = resolve_long_context_mtp_depth(
prompt_tokens=int(prompt_tokens),
requested_depth=max(1, int(requested_depth)),
)
effective_mode = generation_mode
if effective <= 0:
# Ladder asked for MTP off at this context size. The downstream
# generator requires depth >= 1 in mtp mode, so we switch the request
# to AR. The metrics envelope still records the ladder details so
# observers can see why the mode flipped.
effective_mode = "ar"
effective = 0
details["mode_switched_to_ar"] = True
else:
details["mode_switched_to_ar"] = False
return int(effective), effective_mode, details


def _token_window_rate(token_times: list[float], window: int) -> float | None:
if len(token_times) < 2:
return None
Expand Down Expand Up @@ -3241,6 +3288,7 @@ def _request_observability(
session_source: str | None,
request_generation_mode: str,
request_depth: int,
long_context_ladder_details: dict[str, Any] | None = None,
) -> dict[str, Any]:
declared_extra_keys = [
key
Expand Down Expand Up @@ -3292,6 +3340,7 @@ def _request_observability(
"request_depth": int(request_depth),
"request_last_user_preview": user_texts[-1][:180] if user_texts else None,
"request_last_user_chars": len(user_texts[-1]) if user_texts else 0,
"request_long_context_ladder": dict(long_context_ladder_details or {}),
}


Expand Down Expand Up @@ -4406,15 +4455,38 @@ def record_tokens(new_tokens: list[int]) -> None:
envelope["requested_mtp_depth"] = (
requested_mtp_depth if effective_mode == "mtp" else 0
)
envelope["long_context_mtp_depth_policy"] = (
# Surface the request-time ladder (computed in _apply_long_context_ladder)
# whenever it was set, even if generation ran in AR mode because the
# ladder forced MTP off. The in-generation gate writes its own copy
# of `long_context_mtp_depth_policy` into stats only when MTP runs and
# the gate triggers; the request-time ladder runs unconditionally and
# is the more informative one for chat-monitor.
in_generation_policy = (
(stats.get("long_context_mtp_depth_policy") or {})
if effective_mode == "mtp"
else {}
)
ladder_details_from_request = (
(request_observability or {}).get("request_long_context_ladder")
or {}
)
if ladder_details_from_request:
merged_policy: dict[str, Any] = dict(ladder_details_from_request)
# If the in-generation gate also added details, let them override
# (they have the most accurate post-resolution view).
if in_generation_policy:
merged_policy.update(in_generation_policy)
envelope["long_context_mtp_depth_policy"] = merged_policy
else:
envelope["long_context_mtp_depth_policy"] = in_generation_policy
if effective_mode == "ar":
envelope["mtp_depth"] = 0
envelope["requested_mtp_depth"] = 0
envelope["long_context_mtp_depth_policy"] = {}
# Keep the ladder policy visible even in AR mode when the ladder
# is the reason MTP got disabled - clearing it here would hide
# the "why" from observers.
if not ladder_details_from_request:
envelope["long_context_mtp_depth_policy"] = {}
envelope["verify_calls"] = 0
envelope["verify_time_s"] = 0.0
envelope["accepted_by_depth"] = []
Expand Down Expand Up @@ -6510,6 +6582,9 @@ def health() -> dict[str, Any]:
"MTPLX_LONG_CONTEXT_MTP_DEPTH_THRESHOLD"
),
"long_context_mtp_depth": os.environ.get("MTPLX_LONG_CONTEXT_MTP_DEPTH"),
"mtp_long_context_ladder": os.environ.get(
"MTPLX_MTP_LONG_CONTEXT_LADDER"
),
"mtp_position_mode": os.environ.get("MTPLX_MTP_POSITION_MODE"),
"mtp_position_cap": os.environ.get("MTPLX_MTP_POSITION_CAP"),
"mtp_position_period": os.environ.get("MTPLX_MTP_POSITION_PERIOD"),
Expand Down Expand Up @@ -6736,6 +6811,13 @@ async def chat_completions(
tools=tool_specs if tools_active else None,
template_observability=template_observability,
)
request_depth, request_generation_mode, long_context_ladder_details = (
_apply_long_context_ladder(
prompt_tokens=len(prompt_ids),
requested_depth=request_depth,
generation_mode=request_generation_mode,
)
)
current_system_hash = system_prompt_hash(request.messages)
if current_system_hash is not None and not background:
state.main_system_prompt_hash = current_system_hash
Expand Down Expand Up @@ -6780,6 +6862,7 @@ async def chat_completions(
session_source=session_source,
request_generation_mode=request_generation_mode,
request_depth=request_depth,
long_context_ladder_details=long_context_ladder_details,
)
request_observability.update(template_observability)
if template_observability.get("tool_template_fallback"):
Expand Down Expand Up @@ -7758,6 +7841,11 @@ async def completions(request: CompletionRequest) -> Any:
request,
generation_mode=request_generation_mode,
)
request_depth, request_generation_mode, _ = _apply_long_context_ladder(
prompt_tokens=len(prompt_ids),
requested_depth=request_depth,
generation_mode=request_generation_mode,
)
generated = await asyncio.wrap_future(
_submit_foreground_model_work(
state,
Expand Down
Loading