From 6ba776caa51f7f7d9e352884464c6015f5c678e4 Mon Sep 17 00:00:00 2001 From: spicyneuron <183504714+spicyneuron@users.noreply.github.com> Date: Sat, 16 May 2026 20:35:46 +0800 Subject: [PATCH 1/5] Add timings to server responses --- mlx_lm/SERVER.md | 16 +++++++++++ mlx_lm/server.py | 57 ++++++++++++++++++++++++++++++++++---- tests/test_server.py | 65 ++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 129 insertions(+), 9 deletions(-) diff --git a/mlx_lm/SERVER.md b/mlx_lm/SERVER.md index f38ad3dd4..ad6c8af0a 100644 --- a/mlx_lm/SERVER.md +++ b/mlx_lm/SERVER.md @@ -140,6 +140,22 @@ curl localhost:8080/v1/chat/completions \ - `completion_tokens`: The number of tokens generated. - `total_tokens`: The total number of tokens, i.e. the sum of the above two fields. +- `timings`: A dictionary with server-observed generation-service + measurements. Excludes network I/O, response serialization, and + client-side wait. Includes overhead inside the generation service + (queue hops, batch-tick scheduling), so values are approximate: + - `prompt_n`: The number of prompt tokens actually processed (excludes + cached tokens). + - `predicted_n`: The number of tokens generated. + - `prompt_per_second`: `prompt_n` divided by the server-observed time + until the first generated token is produced. + - `predicted_per_second`: `predicted_n` divided by the server-observed + time between the first and last generated tokens. Returns `0` when + fewer than two tokens are generated. + + For streaming requests, `timings` is included in the final usage chunk + when `stream_options.include_usage` is set. + ### List Models Use the `v1/models` endpoint to list available models: diff --git a/mlx_lm/server.py b/mlx_lm/server.py index ce8d95817..44f793922 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -215,6 +215,9 @@ class GenerationContext: prompt: List[int] prompt_cache_count: int = -1 + prompt_start_at: Optional[float] = None + prompt_end_at: Optional[float] = None + decode_end_at: Optional[float] = None _should_stop: bool = False @@ -437,6 +440,22 @@ def _format_top_logprobs(logprobs, top_n, tokenizer) -> Tuple[Dict[str, Any]]: ) +def _make_timings(ctx, prompt_n: int, predicted_n: int) -> Dict[str, Any]: + def rate(n, start, end): + if start is None or end is None or end <= start: + return 0 + return n / (end - start) + + return { + "prompt_per_second": rate(prompt_n, ctx.prompt_start_at, ctx.prompt_end_at), + "predicted_per_second": rate( + predicted_n, ctx.prompt_end_at, ctx.decode_end_at + ), + "prompt_n": prompt_n, + "predicted_n": predicted_n, + } + + class ResponseGenerator: def __init__(self, model_provider: ModelProvider, prompt_cache: LRUPromptCache): self.model_provider = model_provider @@ -850,15 +869,20 @@ def get_next_request(timeout=None): uids_to_remove = [] for _ in self._time_budget: + tic = time.perf_counter() prompt_responses, gen_responses = batch_generator.next() + toc = time.perf_counter() if not prompt_responses and not gen_responses: break # Progress report for prompt processing for r in prompt_responses: result = batch_results[r.uid] + ctx = result["ctx"] + if ctx.prompt_start_at is None: + ctx.prompt_start_at = tic result["rqueue"].put(r.progress) - if result["ctx"]._should_stop: + if ctx._should_stop: uids_to_remove.append(r.uid) # Save the caches at end of segments @@ -881,6 +905,9 @@ def get_next_request(timeout=None): for r in gen_responses: result = batch_results[r.uid] + ctx = result["ctx"] + if ctx.prompt_end_at is None: + ctx.prompt_end_at = toc result["detokenizer"].add_token(r.token) result["rqueue"].put( Response( @@ -899,6 +926,7 @@ def get_next_request(timeout=None): ) if r.finish_reason is not None: + ctx.decode_end_at = toc result["rqueue"].put(None) self.prompt_cache.insert_cache( current_model_key, @@ -908,7 +936,7 @@ def get_next_request(timeout=None): ) del batch_results[r.uid] - if result["ctx"]._should_stop: + if ctx._should_stop: uids_to_remove.append(r.uid) uids_to_remove = self._share_object(uids_to_remove) @@ -973,6 +1001,7 @@ def progress(tokens_processed, tokens_total): cache += make_prompt_cache(self.model_provider.draft_model) # Process the prompt and generate tokens + ctx.prompt_start_at = time.perf_counter() for gen in stream_generate( model=model, tokenizer=tokenizer, @@ -986,6 +1015,8 @@ def progress(tokens_processed, tokens_total): prompt_progress_callback=progress, prefill_step_size=self.cli_args.prefill_step_size, ): + if ctx.prompt_end_at is None: + ctx.prompt_end_at = time.perf_counter() finish_reason = gen.finish_reason sm_state, match_sequence, current_state = sm.match(sm_state, gen.token) if match_sequence is not None and current_state is None: @@ -1013,6 +1044,7 @@ def progress(tokens_processed, tokens_total): if finish_reason is not None: break + ctx.decode_end_at = time.perf_counter() rqueue.put(None) # Save the KV cache again @@ -1268,6 +1300,7 @@ def generate_response( tokens: Optional[List[int]] = None, tool_calls: Optional[List[str]] = None, reasoning_text: Optional[str] = None, + timings: Optional[Dict[str, Any]] = None, ) -> dict: """ Generate a single response packet based on response type (stream or @@ -1345,6 +1378,8 @@ def generate_response( response["usage"]["prompt_tokens_details"] = { "cached_tokens": prompt_cache_count, } + if timings is not None: + response["timings"] = timings choice = response["choices"][0] @@ -1450,6 +1485,9 @@ def keepalive_callback(processed, total): tokens = [] token_logprobs = [] top_tokens = [] + include_usage = (not self.stream) or bool( + self.stream_options and self.stream_options.get("include_usage") + ) try: for gen in response: @@ -1497,6 +1535,11 @@ def keepalive_callback(processed, total): prev_state = gen.state + timings = None + if include_usage: + prompt_n = len(ctx.prompt) - max(ctx.prompt_cache_count, 0) + timings = _make_timings(ctx, prompt_n, len(tokens)) + if prev_state == "tool" and tool_text: tool_calls.append(tool_text) made_tool_call = True @@ -1513,14 +1556,12 @@ def keepalive_callback(processed, total): ) self.wfile.write(f"data: {json.dumps(resp)}\n\n".encode()) self.wfile.flush() - if ( - self.stream_options is not None - and self.stream_options["include_usage"] - ): + if include_usage: resp = self.completion_usage_response( len(ctx.prompt), len(tokens), ctx.prompt_cache_count, + timings=timings, ) self.wfile.write(f"data: {json.dumps(resp)}\n\n".encode()) self.wfile.flush() @@ -1538,6 +1579,7 @@ def keepalive_callback(processed, total): tokens=tokens, reasoning_text=reasoning_text, tool_calls=tool_formatter(tool_calls), + timings=timings, ) if logging.getLogger().isEnabledFor(logging.DEBUG): response_debug = json.dumps(resp, indent="\t") @@ -1556,6 +1598,7 @@ def completion_usage_response( prompt_token_count: Optional[int] = None, completion_token_count: Optional[int] = None, prompt_cache_count: Optional[int] = None, + timings: Optional[Dict[str, Any]] = None, ): response = { "id": self.request_id, @@ -1574,6 +1617,8 @@ def completion_usage_response( response["usage"]["prompt_tokens_details"] = { "cached_tokens": prompt_cache_count, } + if timings is not None: + response["timings"] = timings return response def handle_chat_completions(self) -> CompletionRequest: diff --git a/tests/test_server.py b/tests/test_server.py index 9a8a2ad14..26cf25b0d 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -21,6 +21,29 @@ from mlx_lm.utils import load +def assert_usage_timings(test_case, response_body): + test_case.assertIn("usage", response_body) + test_case.assertIn("timings", response_body) + timings = response_body["timings"] + cached_tokens = ( + response_body["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0) + ) + test_case.assertEqual( + timings["prompt_n"], + response_body["usage"]["prompt_tokens"] - cached_tokens, + ) + test_case.assertEqual( + timings["predicted_n"], + response_body["usage"]["completion_tokens"], + ) + if timings["prompt_n"] > 0: + test_case.assertGreater(timings["prompt_per_second"], 0) + # predicted_per_second needs at least two tokens; with one, the + # first-token and last-token timestamps collapse. + if timings["predicted_n"] > 1: + test_case.assertGreater(timings["predicted_per_second"], 0) + + class DummyModelProvider: def __init__(self, with_draft=False): HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" @@ -201,6 +224,7 @@ def test_handle_completions(self): self.assertIn("id", response_body) self.assertIn("choices", response_body) + assert_usage_timings(self, response_body) first_text = response_body["choices"][0]["text"] self.assertEqual( first_text, @@ -221,9 +245,10 @@ def test_handle_chat_completions(self): ], } response = requests.post(url, json=chat_post_data) - response_body = response.text + response_body = json.loads(response.text) self.assertIn("id", response_body) self.assertIn("choices", response_body) + assert_usage_timings(self, response_body) def test_handle_chat_completions_with_content_fragments(self): url = f"http://localhost:{self.port}/v1/chat/completions" @@ -248,6 +273,38 @@ def test_handle_chat_completions_with_content_fragments(self): self.assertIn("id", response_body) self.assertIn("choices", response_body) + def test_streaming_include_usage_timings(self): + url = f"http://localhost:{self.port}/v1/chat/completions" + chat_post_data = { + "model": "chat_model", + "max_tokens": 10, + "temperature": 0.0, + "stream": True, + "stream_options": {"include_usage": True}, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + ], + } + + response = requests.post(url, json=chat_post_data, stream=True) + self.assertEqual(response.status_code, 200) + + usage_count = 0 + for chunk in response.iter_lines(): + if not chunk: + continue + data = chunk.decode("utf-8") + if not data.startswith("data: ") or data == "data: [DONE]": + continue + chunk_data = json.loads(data[6:]) + if chunk_data.get("usage"): + self.assertEqual(chunk_data["choices"], []) + assert_usage_timings(self, chunk_data) + usage_count += 1 + + self.assertEqual(usage_count, 1) + def test_handle_chat_completions_with_null_tool_content(self): url = f"http://localhost:{self.port}/v1/chat/completions" chat_post_data = { @@ -361,7 +418,7 @@ def test_handle_completions_with_draft_model(self): response_body = json.loads(response.text) self.assertIn("id", response_body) self.assertIn("choices", response_body) - self.assertIn("usage", response_body) + assert_usage_timings(self, response_body) # Check that tokens were generated self.assertTrue(response_body["usage"]["completion_tokens"] > 0) @@ -385,7 +442,7 @@ def test_handle_chat_completions_with_draft_model(self): response_body = json.loads(response.text) self.assertIn("id", response_body) self.assertIn("choices", response_body) - self.assertIn("usage", response_body) + assert_usage_timings(self, response_body) # Check that tokens were generated self.assertTrue(response_body["usage"]["completion_tokens"] > 0) @@ -458,6 +515,8 @@ def test_prompt_cache_with_draft_model(self): self.assertIn("choices", first_response_body) self.assertIn("choices", second_response_body) + assert_usage_timings(self, first_response_body) + assert_usage_timings(self, second_response_body) self.assertIn("message", first_response_body["choices"][0]) self.assertIn("message", second_response_body["choices"][0]) self.assertIn("content", first_response_body["choices"][0]["message"]) From e154220ed77ff4e1d2e823964abc3bd283f1a3cf Mon Sep 17 00:00:00 2001 From: spicyneuron <183504714+spicyneuron@users.noreply.github.com> Date: Sun, 17 May 2026 16:20:54 +0800 Subject: [PATCH 2/5] Add cache_n support --- mlx_lm/server.py | 5 +++-- tests/test_server.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 44f793922..f150a2f87 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -440,7 +440,7 @@ def _format_top_logprobs(logprobs, top_n, tokenizer) -> Tuple[Dict[str, Any]]: ) -def _make_timings(ctx, prompt_n: int, predicted_n: int) -> Dict[str, Any]: +def _make_timings(ctx, prompt_n: int, predicted_n: int, cache_n: int) -> Dict[str, Any]: def rate(n, start, end): if start is None or end is None or end <= start: return 0 @@ -453,6 +453,7 @@ def rate(n, start, end): ), "prompt_n": prompt_n, "predicted_n": predicted_n, + "cache_n": cache_n, } @@ -1538,7 +1539,7 @@ def keepalive_callback(processed, total): timings = None if include_usage: prompt_n = len(ctx.prompt) - max(ctx.prompt_cache_count, 0) - timings = _make_timings(ctx, prompt_n, len(tokens)) + timings = _make_timings(ctx, prompt_n, len(tokens), max(ctx.prompt_cache_count, 0)) if prev_state == "tool" and tool_text: tool_calls.append(tool_text) diff --git a/tests/test_server.py b/tests/test_server.py index 26cf25b0d..c2448a122 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -36,6 +36,7 @@ def assert_usage_timings(test_case, response_body): timings["predicted_n"], response_body["usage"]["completion_tokens"], ) + test_case.assertEqual(timings["cache_n"], cached_tokens) if timings["prompt_n"] > 0: test_case.assertGreater(timings["prompt_per_second"], 0) # predicted_per_second needs at least two tokens; with one, the From a6b8e59caadef7d4f9834d051283c8dcd73804b7 Mon Sep 17 00:00:00 2001 From: spicyneuron <183504714+spicyneuron@users.noreply.github.com> Date: Sun, 17 May 2026 17:32:06 +0800 Subject: [PATCH 3/5] Report elapsed time in timings --- mlx_lm/server.py | 14 ++++++++------ tests/test_server.py | 8 ++++---- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index f150a2f87..119626202 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -441,18 +441,20 @@ def _format_top_logprobs(logprobs, top_n, tokenizer) -> Tuple[Dict[str, Any]]: def _make_timings(ctx, prompt_n: int, predicted_n: int, cache_n: int) -> Dict[str, Any]: - def rate(n, start, end): + def elapsed(start, end): if start is None or end is None or end <= start: return 0 - return n / (end - start) + return end - start + prompt_s = elapsed(ctx.prompt_start_at, ctx.prompt_end_at) + predicted_s = elapsed(ctx.prompt_end_at, ctx.decode_end_at) return { - "prompt_per_second": rate(prompt_n, ctx.prompt_start_at, ctx.prompt_end_at), - "predicted_per_second": rate( - predicted_n, ctx.prompt_end_at, ctx.decode_end_at - ), "prompt_n": prompt_n, + "prompt_ms": prompt_s * 1000, + "prompt_per_second": (prompt_n / prompt_s) if prompt_s else 0, "predicted_n": predicted_n, + "predicted_ms": predicted_s * 1000, + "predicted_per_second": (predicted_n / predicted_s) if predicted_s else 0, "cache_n": cache_n, } diff --git a/tests/test_server.py b/tests/test_server.py index c2448a122..09c20d26c 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -37,11 +37,11 @@ def assert_usage_timings(test_case, response_body): response_body["usage"]["completion_tokens"], ) test_case.assertEqual(timings["cache_n"], cached_tokens) - if timings["prompt_n"] > 0: + test_case.assertGreaterEqual(timings["prompt_ms"], 0) + test_case.assertGreaterEqual(timings["predicted_ms"], 0) + if timings["prompt_ms"] > 0: test_case.assertGreater(timings["prompt_per_second"], 0) - # predicted_per_second needs at least two tokens; with one, the - # first-token and last-token timestamps collapse. - if timings["predicted_n"] > 1: + if timings["predicted_ms"] > 0: test_case.assertGreater(timings["predicted_per_second"], 0) From aad8db33a7294773ee1d49d44e9a5aa205b3df8b Mon Sep 17 00:00:00 2001 From: spicyneuron <183504714+spicyneuron@users.noreply.github.com> Date: Sun, 17 May 2026 17:32:11 +0800 Subject: [PATCH 4/5] Update SERVER.md --- mlx_lm/SERVER.md | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/mlx_lm/SERVER.md b/mlx_lm/SERVER.md index ad6c8af0a..22eaebf95 100644 --- a/mlx_lm/SERVER.md +++ b/mlx_lm/SERVER.md @@ -140,21 +140,22 @@ curl localhost:8080/v1/chat/completions \ - `completion_tokens`: The number of tokens generated. - `total_tokens`: The total number of tokens, i.e. the sum of the above two fields. -- `timings`: A dictionary with server-observed generation-service - measurements. Excludes network I/O, response serialization, and - client-side wait. Includes overhead inside the generation service - (queue hops, batch-tick scheduling), so values are approximate: - - `prompt_n`: The number of prompt tokens actually processed (excludes - cached tokens). - - `predicted_n`: The number of tokens generated. - - `prompt_per_second`: `prompt_n` divided by the server-observed time - until the first generated token is produced. - - `predicted_per_second`: `predicted_n` divided by the server-observed - time between the first and last generated tokens. Returns `0` when - fewer than two tokens are generated. - - For streaming requests, `timings` is included in the final usage chunk - when `stream_options.include_usage` is set. +- `timings`: Server-side timing measurements, following the shape used by + llama.cpp and several open-source clients. Times are measured around + the generation service only (no network or serialization) and include + some internal scheduling overhead, so treat them as approximate. + - `prompt_n`: Prompt tokens processed (excludes cached tokens). + - `prompt_ms`: Time to first generated token, in milliseconds. + - `prompt_per_second`: `prompt_n / (prompt_ms / 1000)`, or `0` if + `prompt_ms` is `0`. + - `predicted_n`: Tokens generated. + - `predicted_ms`: Time from first to last generated token, in + milliseconds. `0` when fewer than two tokens are generated. + - `predicted_per_second`: `predicted_n / (predicted_ms / 1000)`, or + `0` if `predicted_ms` is `0`. + + For streaming requests, `timings` rides on the final usage chunk and + requires `stream_options.include_usage`. ### List Models From c07212170512ce9e9af3b154d2894a73af910537 Mon Sep 17 00:00:00 2001 From: spicyneuron <183504714+spicyneuron@users.noreply.github.com> Date: Sun, 17 May 2026 17:39:12 +0800 Subject: [PATCH 5/5] Hoist cache_n --- mlx_lm/server.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 119626202..1d7234fd3 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -1540,8 +1540,9 @@ def keepalive_callback(processed, total): timings = None if include_usage: - prompt_n = len(ctx.prompt) - max(ctx.prompt_cache_count, 0) - timings = _make_timings(ctx, prompt_n, len(tokens), max(ctx.prompt_cache_count, 0)) + cache_n = max(ctx.prompt_cache_count, 0) + prompt_n = len(ctx.prompt) - cache_n + timings = _make_timings(ctx, prompt_n, len(tokens), cache_n) if prev_state == "tool" and tool_text: tool_calls.append(tool_text)