Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions mlx_lm/SERVER.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,23 @@ curl localhost:8080/v1/chat/completions \
- `completion_tokens`: The number of tokens generated.
- `total_tokens`: The total number of tokens, i.e. the sum of the above two fields.

- `timings`: Server-side timing measurements, following the shape used by
llama.cpp and several open-source clients. Times are measured around
the generation service only (no network or serialization) and include
some internal scheduling overhead, so treat them as approximate.
- `prompt_n`: Prompt tokens processed (excludes cached tokens).
- `prompt_ms`: Time to first generated token, in milliseconds.
- `prompt_per_second`: `prompt_n / (prompt_ms / 1000)`, or `0` if
`prompt_ms` is `0`.
- `predicted_n`: Tokens generated.
- `predicted_ms`: Time from first to last generated token, in
milliseconds. `0` when fewer than two tokens are generated.
- `predicted_per_second`: `predicted_n / (predicted_ms / 1000)`, or
`0` if `predicted_ms` is `0`.

For streaming requests, `timings` rides on the final usage chunk and
requires `stream_options.include_usage`.

### List Models

Use the `v1/models` endpoint to list available models:
Expand Down
61 changes: 55 additions & 6 deletions mlx_lm/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ class GenerationContext:

prompt: List[int]
prompt_cache_count: int = -1
prompt_start_at: Optional[float] = None
prompt_end_at: Optional[float] = None
decode_end_at: Optional[float] = None

_should_stop: bool = False

Expand Down Expand Up @@ -437,6 +440,25 @@ def _format_top_logprobs(logprobs, top_n, tokenizer) -> Tuple[Dict[str, Any]]:
)


def _make_timings(ctx, prompt_n: int, predicted_n: int, cache_n: int) -> Dict[str, Any]:
def elapsed(start, end):
if start is None or end is None or end <= start:
return 0
return end - start

prompt_s = elapsed(ctx.prompt_start_at, ctx.prompt_end_at)
predicted_s = elapsed(ctx.prompt_end_at, ctx.decode_end_at)
return {
"prompt_n": prompt_n,
"prompt_ms": prompt_s * 1000,
"prompt_per_second": (prompt_n / prompt_s) if prompt_s else 0,
"predicted_n": predicted_n,
"predicted_ms": predicted_s * 1000,
"predicted_per_second": (predicted_n / predicted_s) if predicted_s else 0,
"cache_n": cache_n,
}


class ResponseGenerator:
def __init__(self, model_provider: ModelProvider, prompt_cache: LRUPromptCache):
self.model_provider = model_provider
Expand Down Expand Up @@ -850,15 +872,20 @@ def get_next_request(timeout=None):

uids_to_remove = []
for _ in self._time_budget:
tic = time.perf_counter()
prompt_responses, gen_responses = batch_generator.next()
toc = time.perf_counter()
if not prompt_responses and not gen_responses:
break

# Progress report for prompt processing
for r in prompt_responses:
result = batch_results[r.uid]
ctx = result["ctx"]
if ctx.prompt_start_at is None:
ctx.prompt_start_at = tic
result["rqueue"].put(r.progress)
if result["ctx"]._should_stop:
if ctx._should_stop:
uids_to_remove.append(r.uid)

# Save the caches at end of segments
Expand All @@ -881,6 +908,9 @@ def get_next_request(timeout=None):

for r in gen_responses:
result = batch_results[r.uid]
ctx = result["ctx"]
if ctx.prompt_end_at is None:
ctx.prompt_end_at = toc
result["detokenizer"].add_token(r.token)
result["rqueue"].put(
Response(
Expand All @@ -899,6 +929,7 @@ def get_next_request(timeout=None):
)

if r.finish_reason is not None:
ctx.decode_end_at = toc
result["rqueue"].put(None)
self.prompt_cache.insert_cache(
current_model_key,
Expand All @@ -908,7 +939,7 @@ def get_next_request(timeout=None):
)
del batch_results[r.uid]

if result["ctx"]._should_stop:
if ctx._should_stop:
uids_to_remove.append(r.uid)

uids_to_remove = self._share_object(uids_to_remove)
Expand Down Expand Up @@ -973,6 +1004,7 @@ def progress(tokens_processed, tokens_total):
cache += make_prompt_cache(self.model_provider.draft_model)

# Process the prompt and generate tokens
ctx.prompt_start_at = time.perf_counter()
for gen in stream_generate(
model=model,
tokenizer=tokenizer,
Expand All @@ -986,6 +1018,8 @@ def progress(tokens_processed, tokens_total):
prompt_progress_callback=progress,
prefill_step_size=self.cli_args.prefill_step_size,
):
if ctx.prompt_end_at is None:
ctx.prompt_end_at = time.perf_counter()
finish_reason = gen.finish_reason
sm_state, match_sequence, current_state = sm.match(sm_state, gen.token)
if match_sequence is not None and current_state is None:
Expand Down Expand Up @@ -1013,6 +1047,7 @@ def progress(tokens_processed, tokens_total):
if finish_reason is not None:
break

ctx.decode_end_at = time.perf_counter()
rqueue.put(None)

# Save the KV cache again
Expand Down Expand Up @@ -1268,6 +1303,7 @@ def generate_response(
tokens: Optional[List[int]] = None,
tool_calls: Optional[List[str]] = None,
reasoning_text: Optional[str] = None,
timings: Optional[Dict[str, Any]] = None,
) -> dict:
"""
Generate a single response packet based on response type (stream or
Expand Down Expand Up @@ -1345,6 +1381,8 @@ def generate_response(
response["usage"]["prompt_tokens_details"] = {
"cached_tokens": prompt_cache_count,
}
if timings is not None:
response["timings"] = timings

choice = response["choices"][0]

Expand Down Expand Up @@ -1450,6 +1488,9 @@ def keepalive_callback(processed, total):
tokens = []
token_logprobs = []
top_tokens = []
include_usage = (not self.stream) or bool(
self.stream_options and self.stream_options.get("include_usage")
)

try:
for gen in response:
Expand Down Expand Up @@ -1497,6 +1538,12 @@ def keepalive_callback(processed, total):

prev_state = gen.state

timings = None
if include_usage:
cache_n = max(ctx.prompt_cache_count, 0)
prompt_n = len(ctx.prompt) - cache_n
timings = _make_timings(ctx, prompt_n, len(tokens), cache_n)

if prev_state == "tool" and tool_text:
tool_calls.append(tool_text)
made_tool_call = True
Expand All @@ -1513,14 +1560,12 @@ def keepalive_callback(processed, total):
)
self.wfile.write(f"data: {json.dumps(resp)}\n\n".encode())
self.wfile.flush()
if (
self.stream_options is not None
and self.stream_options["include_usage"]
):
if include_usage:
resp = self.completion_usage_response(
len(ctx.prompt),
len(tokens),
ctx.prompt_cache_count,
timings=timings,
)
self.wfile.write(f"data: {json.dumps(resp)}\n\n".encode())
self.wfile.flush()
Expand All @@ -1538,6 +1583,7 @@ def keepalive_callback(processed, total):
tokens=tokens,
reasoning_text=reasoning_text,
tool_calls=tool_formatter(tool_calls),
timings=timings,
)
if logging.getLogger().isEnabledFor(logging.DEBUG):
response_debug = json.dumps(resp, indent="\t")
Expand All @@ -1556,6 +1602,7 @@ def completion_usage_response(
prompt_token_count: Optional[int] = None,
completion_token_count: Optional[int] = None,
prompt_cache_count: Optional[int] = None,
timings: Optional[Dict[str, Any]] = None,
):
response = {
"id": self.request_id,
Expand All @@ -1574,6 +1621,8 @@ def completion_usage_response(
response["usage"]["prompt_tokens_details"] = {
"cached_tokens": prompt_cache_count,
}
if timings is not None:
response["timings"] = timings
return response

def handle_chat_completions(self) -> CompletionRequest:
Expand Down
66 changes: 63 additions & 3 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,30 @@
from mlx_lm.utils import load


def assert_usage_timings(test_case, response_body):
test_case.assertIn("usage", response_body)
test_case.assertIn("timings", response_body)
timings = response_body["timings"]
cached_tokens = (
response_body["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
)
test_case.assertEqual(
timings["prompt_n"],
response_body["usage"]["prompt_tokens"] - cached_tokens,
)
test_case.assertEqual(
timings["predicted_n"],
response_body["usage"]["completion_tokens"],
)
test_case.assertEqual(timings["cache_n"], cached_tokens)
test_case.assertGreaterEqual(timings["prompt_ms"], 0)
test_case.assertGreaterEqual(timings["predicted_ms"], 0)
if timings["prompt_ms"] > 0:
test_case.assertGreater(timings["prompt_per_second"], 0)
if timings["predicted_ms"] > 0:
test_case.assertGreater(timings["predicted_per_second"], 0)


class DummyModelProvider:
def __init__(self, with_draft=False):
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
Expand Down Expand Up @@ -201,6 +225,7 @@ def test_handle_completions(self):

self.assertIn("id", response_body)
self.assertIn("choices", response_body)
assert_usage_timings(self, response_body)
first_text = response_body["choices"][0]["text"]
self.assertEqual(
first_text,
Expand All @@ -221,9 +246,10 @@ def test_handle_chat_completions(self):
],
}
response = requests.post(url, json=chat_post_data)
response_body = response.text
response_body = json.loads(response.text)
self.assertIn("id", response_body)
self.assertIn("choices", response_body)
assert_usage_timings(self, response_body)

def test_handle_chat_completions_with_content_fragments(self):
url = f"http://localhost:{self.port}/v1/chat/completions"
Expand All @@ -248,6 +274,38 @@ def test_handle_chat_completions_with_content_fragments(self):
self.assertIn("id", response_body)
self.assertIn("choices", response_body)

def test_streaming_include_usage_timings(self):
url = f"http://localhost:{self.port}/v1/chat/completions"
chat_post_data = {
"model": "chat_model",
"max_tokens": 10,
"temperature": 0.0,
"stream": True,
"stream_options": {"include_usage": True},
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"},
],
}

response = requests.post(url, json=chat_post_data, stream=True)
self.assertEqual(response.status_code, 200)

usage_count = 0
for chunk in response.iter_lines():
if not chunk:
continue
data = chunk.decode("utf-8")
if not data.startswith("data: ") or data == "data: [DONE]":
continue
chunk_data = json.loads(data[6:])
if chunk_data.get("usage"):
self.assertEqual(chunk_data["choices"], [])
assert_usage_timings(self, chunk_data)
usage_count += 1

self.assertEqual(usage_count, 1)

def test_handle_chat_completions_with_null_tool_content(self):
url = f"http://localhost:{self.port}/v1/chat/completions"
chat_post_data = {
Expand Down Expand Up @@ -361,7 +419,7 @@ def test_handle_completions_with_draft_model(self):
response_body = json.loads(response.text)
self.assertIn("id", response_body)
self.assertIn("choices", response_body)
self.assertIn("usage", response_body)
assert_usage_timings(self, response_body)

# Check that tokens were generated
self.assertTrue(response_body["usage"]["completion_tokens"] > 0)
Expand All @@ -385,7 +443,7 @@ def test_handle_chat_completions_with_draft_model(self):
response_body = json.loads(response.text)
self.assertIn("id", response_body)
self.assertIn("choices", response_body)
self.assertIn("usage", response_body)
assert_usage_timings(self, response_body)

# Check that tokens were generated
self.assertTrue(response_body["usage"]["completion_tokens"] > 0)
Expand Down Expand Up @@ -458,6 +516,8 @@ def test_prompt_cache_with_draft_model(self):

self.assertIn("choices", first_response_body)
self.assertIn("choices", second_response_body)
assert_usage_timings(self, first_response_body)
assert_usage_timings(self, second_response_body)
self.assertIn("message", first_response_body["choices"][0])
self.assertIn("message", second_response_body["choices"][0])
self.assertIn("content", first_response_body["choices"][0]["message"])
Expand Down