diff --git a/mlx_lm/SERVER.md b/mlx_lm/SERVER.md index f38ad3dd4..22eaebf95 100644 --- a/mlx_lm/SERVER.md +++ b/mlx_lm/SERVER.md @@ -140,6 +140,23 @@ curl localhost:8080/v1/chat/completions \ - `completion_tokens`: The number of tokens generated. - `total_tokens`: The total number of tokens, i.e. the sum of the above two fields. +- `timings`: Server-side timing measurements, following the shape used by + llama.cpp and several open-source clients. Times are measured around + the generation service only (no network or serialization) and include + some internal scheduling overhead, so treat them as approximate. + - `prompt_n`: Prompt tokens processed (excludes cached tokens). + - `prompt_ms`: Time to first generated token, in milliseconds. + - `prompt_per_second`: `prompt_n / (prompt_ms / 1000)`, or `0` if + `prompt_ms` is `0`. + - `predicted_n`: Tokens generated. + - `predicted_ms`: Time from first to last generated token, in + milliseconds. `0` when fewer than two tokens are generated. + - `predicted_per_second`: `predicted_n / (predicted_ms / 1000)`, or + `0` if `predicted_ms` is `0`. + + For streaming requests, `timings` rides on the final usage chunk and + requires `stream_options.include_usage`. + ### List Models Use the `v1/models` endpoint to list available models: diff --git a/mlx_lm/server.py b/mlx_lm/server.py index ce8d95817..1d7234fd3 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -215,6 +215,9 @@ class GenerationContext: prompt: List[int] prompt_cache_count: int = -1 + prompt_start_at: Optional[float] = None + prompt_end_at: Optional[float] = None + decode_end_at: Optional[float] = None _should_stop: bool = False @@ -437,6 +440,25 @@ def _format_top_logprobs(logprobs, top_n, tokenizer) -> Tuple[Dict[str, Any]]: ) +def _make_timings(ctx, prompt_n: int, predicted_n: int, cache_n: int) -> Dict[str, Any]: + def elapsed(start, end): + if start is None or end is None or end <= start: + return 0 + return end - start + + prompt_s = elapsed(ctx.prompt_start_at, ctx.prompt_end_at) + predicted_s = elapsed(ctx.prompt_end_at, ctx.decode_end_at) + return { + "prompt_n": prompt_n, + "prompt_ms": prompt_s * 1000, + "prompt_per_second": (prompt_n / prompt_s) if prompt_s else 0, + "predicted_n": predicted_n, + "predicted_ms": predicted_s * 1000, + "predicted_per_second": (predicted_n / predicted_s) if predicted_s else 0, + "cache_n": cache_n, + } + + class ResponseGenerator: def __init__(self, model_provider: ModelProvider, prompt_cache: LRUPromptCache): self.model_provider = model_provider @@ -850,15 +872,20 @@ def get_next_request(timeout=None): uids_to_remove = [] for _ in self._time_budget: + tic = time.perf_counter() prompt_responses, gen_responses = batch_generator.next() + toc = time.perf_counter() if not prompt_responses and not gen_responses: break # Progress report for prompt processing for r in prompt_responses: result = batch_results[r.uid] + ctx = result["ctx"] + if ctx.prompt_start_at is None: + ctx.prompt_start_at = tic result["rqueue"].put(r.progress) - if result["ctx"]._should_stop: + if ctx._should_stop: uids_to_remove.append(r.uid) # Save the caches at end of segments @@ -881,6 +908,9 @@ def get_next_request(timeout=None): for r in gen_responses: result = batch_results[r.uid] + ctx = result["ctx"] + if ctx.prompt_end_at is None: + ctx.prompt_end_at = toc result["detokenizer"].add_token(r.token) result["rqueue"].put( Response( @@ -899,6 +929,7 @@ def get_next_request(timeout=None): ) if r.finish_reason is not None: + ctx.decode_end_at = toc result["rqueue"].put(None) self.prompt_cache.insert_cache( current_model_key, @@ -908,7 +939,7 @@ def get_next_request(timeout=None): ) del batch_results[r.uid] - if result["ctx"]._should_stop: + if ctx._should_stop: uids_to_remove.append(r.uid) uids_to_remove = self._share_object(uids_to_remove) @@ -973,6 +1004,7 @@ def progress(tokens_processed, tokens_total): cache += make_prompt_cache(self.model_provider.draft_model) # Process the prompt and generate tokens + ctx.prompt_start_at = time.perf_counter() for gen in stream_generate( model=model, tokenizer=tokenizer, @@ -986,6 +1018,8 @@ def progress(tokens_processed, tokens_total): prompt_progress_callback=progress, prefill_step_size=self.cli_args.prefill_step_size, ): + if ctx.prompt_end_at is None: + ctx.prompt_end_at = time.perf_counter() finish_reason = gen.finish_reason sm_state, match_sequence, current_state = sm.match(sm_state, gen.token) if match_sequence is not None and current_state is None: @@ -1013,6 +1047,7 @@ def progress(tokens_processed, tokens_total): if finish_reason is not None: break + ctx.decode_end_at = time.perf_counter() rqueue.put(None) # Save the KV cache again @@ -1268,6 +1303,7 @@ def generate_response( tokens: Optional[List[int]] = None, tool_calls: Optional[List[str]] = None, reasoning_text: Optional[str] = None, + timings: Optional[Dict[str, Any]] = None, ) -> dict: """ Generate a single response packet based on response type (stream or @@ -1345,6 +1381,8 @@ def generate_response( response["usage"]["prompt_tokens_details"] = { "cached_tokens": prompt_cache_count, } + if timings is not None: + response["timings"] = timings choice = response["choices"][0] @@ -1450,6 +1488,9 @@ def keepalive_callback(processed, total): tokens = [] token_logprobs = [] top_tokens = [] + include_usage = (not self.stream) or bool( + self.stream_options and self.stream_options.get("include_usage") + ) try: for gen in response: @@ -1497,6 +1538,12 @@ def keepalive_callback(processed, total): prev_state = gen.state + timings = None + if include_usage: + cache_n = max(ctx.prompt_cache_count, 0) + prompt_n = len(ctx.prompt) - cache_n + timings = _make_timings(ctx, prompt_n, len(tokens), cache_n) + if prev_state == "tool" and tool_text: tool_calls.append(tool_text) made_tool_call = True @@ -1513,14 +1560,12 @@ def keepalive_callback(processed, total): ) self.wfile.write(f"data: {json.dumps(resp)}\n\n".encode()) self.wfile.flush() - if ( - self.stream_options is not None - and self.stream_options["include_usage"] - ): + if include_usage: resp = self.completion_usage_response( len(ctx.prompt), len(tokens), ctx.prompt_cache_count, + timings=timings, ) self.wfile.write(f"data: {json.dumps(resp)}\n\n".encode()) self.wfile.flush() @@ -1538,6 +1583,7 @@ def keepalive_callback(processed, total): tokens=tokens, reasoning_text=reasoning_text, tool_calls=tool_formatter(tool_calls), + timings=timings, ) if logging.getLogger().isEnabledFor(logging.DEBUG): response_debug = json.dumps(resp, indent="\t") @@ -1556,6 +1602,7 @@ def completion_usage_response( prompt_token_count: Optional[int] = None, completion_token_count: Optional[int] = None, prompt_cache_count: Optional[int] = None, + timings: Optional[Dict[str, Any]] = None, ): response = { "id": self.request_id, @@ -1574,6 +1621,8 @@ def completion_usage_response( response["usage"]["prompt_tokens_details"] = { "cached_tokens": prompt_cache_count, } + if timings is not None: + response["timings"] = timings return response def handle_chat_completions(self) -> CompletionRequest: diff --git a/tests/test_server.py b/tests/test_server.py index 9a8a2ad14..09c20d26c 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -21,6 +21,30 @@ from mlx_lm.utils import load +def assert_usage_timings(test_case, response_body): + test_case.assertIn("usage", response_body) + test_case.assertIn("timings", response_body) + timings = response_body["timings"] + cached_tokens = ( + response_body["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0) + ) + test_case.assertEqual( + timings["prompt_n"], + response_body["usage"]["prompt_tokens"] - cached_tokens, + ) + test_case.assertEqual( + timings["predicted_n"], + response_body["usage"]["completion_tokens"], + ) + test_case.assertEqual(timings["cache_n"], cached_tokens) + test_case.assertGreaterEqual(timings["prompt_ms"], 0) + test_case.assertGreaterEqual(timings["predicted_ms"], 0) + if timings["prompt_ms"] > 0: + test_case.assertGreater(timings["prompt_per_second"], 0) + if timings["predicted_ms"] > 0: + test_case.assertGreater(timings["predicted_per_second"], 0) + + class DummyModelProvider: def __init__(self, with_draft=False): HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" @@ -201,6 +225,7 @@ def test_handle_completions(self): self.assertIn("id", response_body) self.assertIn("choices", response_body) + assert_usage_timings(self, response_body) first_text = response_body["choices"][0]["text"] self.assertEqual( first_text, @@ -221,9 +246,10 @@ def test_handle_chat_completions(self): ], } response = requests.post(url, json=chat_post_data) - response_body = response.text + response_body = json.loads(response.text) self.assertIn("id", response_body) self.assertIn("choices", response_body) + assert_usage_timings(self, response_body) def test_handle_chat_completions_with_content_fragments(self): url = f"http://localhost:{self.port}/v1/chat/completions" @@ -248,6 +274,38 @@ def test_handle_chat_completions_with_content_fragments(self): self.assertIn("id", response_body) self.assertIn("choices", response_body) + def test_streaming_include_usage_timings(self): + url = f"http://localhost:{self.port}/v1/chat/completions" + chat_post_data = { + "model": "chat_model", + "max_tokens": 10, + "temperature": 0.0, + "stream": True, + "stream_options": {"include_usage": True}, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + ], + } + + response = requests.post(url, json=chat_post_data, stream=True) + self.assertEqual(response.status_code, 200) + + usage_count = 0 + for chunk in response.iter_lines(): + if not chunk: + continue + data = chunk.decode("utf-8") + if not data.startswith("data: ") or data == "data: [DONE]": + continue + chunk_data = json.loads(data[6:]) + if chunk_data.get("usage"): + self.assertEqual(chunk_data["choices"], []) + assert_usage_timings(self, chunk_data) + usage_count += 1 + + self.assertEqual(usage_count, 1) + def test_handle_chat_completions_with_null_tool_content(self): url = f"http://localhost:{self.port}/v1/chat/completions" chat_post_data = { @@ -361,7 +419,7 @@ def test_handle_completions_with_draft_model(self): response_body = json.loads(response.text) self.assertIn("id", response_body) self.assertIn("choices", response_body) - self.assertIn("usage", response_body) + assert_usage_timings(self, response_body) # Check that tokens were generated self.assertTrue(response_body["usage"]["completion_tokens"] > 0) @@ -385,7 +443,7 @@ def test_handle_chat_completions_with_draft_model(self): response_body = json.loads(response.text) self.assertIn("id", response_body) self.assertIn("choices", response_body) - self.assertIn("usage", response_body) + assert_usage_timings(self, response_body) # Check that tokens were generated self.assertTrue(response_body["usage"]["completion_tokens"] > 0) @@ -458,6 +516,8 @@ def test_prompt_cache_with_draft_model(self): self.assertIn("choices", first_response_body) self.assertIn("choices", second_response_body) + assert_usage_timings(self, first_response_body) + assert_usage_timings(self, second_response_body) self.assertIn("message", first_response_body["choices"][0]) self.assertIn("message", second_response_body["choices"][0]) self.assertIn("content", first_response_body["choices"][0]["message"])