diff --git a/benchmark_serving.py b/benchmark_serving.py
index 43038df..f0d43d1 100644
--- a/benchmark_serving.py
+++ b/benchmark_serving.py
@@ -28,7 +28,7 @@
 import requests
 import time
 from typing import AsyncGenerator, List, Optional, Tuple, Dict
-from prometheus_client import start_http_server, Histogram, Gauge
+from prometheus_client import start_http_server, Histogram, Gauge, Counter
 import logging
 import google.auth
@@ -53,10 +53,38 @@
 tpot_metric = Histogram('LatencyProfileGenerator:time_per_output_token', 'Time per output token per request (excluding first token)')
 ttft_metric = Histogram('LatencyProfileGenerator:time_to_first_token', 'Time to first token per request')
 active_requests_metric = Gauge('LatencyProfileGenerator:active_requests', 'How many requests actively being processed')
+total_request_count = Counter('LatencyProfileGenerator:request_count', 'How many total requests have been sent')
+
+# Singleton class to track requests for QPS counting and calculation.
+class AsyncRequestCounter:
+  _instance = None
+  _lock = asyncio.Lock()
+
+  async def __new__(cls, target_requests=None, *args, **kwargs):
+    async with cls._lock:
+      if not cls._instance:
+        cls._instance = super().__new__(cls)
+        cls._instance._count = 0
+        cls._instance._start_time = time.time()
+        cls._instance._target_requests = target_requests
+      return cls._instance
+
+  async def increment(self):
+    async with self._lock:
+      self._count += 1
+      if self._count == self._target_requests:
+        self._end_time = time.time()
+
+  async def get_qps(self):
+    return self._count / (self._end_time - self._start_time)
+
 # Add trace config for monitoring in flight requests
 async def on_request_start(session, trace_config_ctx, params):
   active_requests_metric.inc()
+  total_request_count.inc()
+  counter = await AsyncRequestCounter()
+  await counter.increment()
 
 async def on_request_end(session, trace_config_ctx, params):
   active_requests_metric.dec()
@@ -454,6 +482,8 @@ async def benchmark(
   model_weights = list(models_dict.values())
 
   benchmark_start_time = time.time()
+  # Initialize the counter with target prompts
+  await AsyncRequestCounter(args.num_prompts)
   tasks: List[asyncio.Task] = []
   prompts_sent = 0
   async for request in generate_next_request(input_requests, args.request_rate):
@@ -495,12 +525,12 @@ async def benchmark(
 
   benchmark_duration = time.time() - benchmark_start_time
 
-  print_and_save_result(args, benchmark_duration, prompts_sent, "weighted",
+  await print_and_save_result(args, benchmark_duration, prompts_sent, "weighted",
                         overall_results["latencies"], overall_results["ttfts"],
                         overall_results["itls"], overall_results["tpots"],
                         overall_results["errors"])
   for model, data in per_model_results.items():
-    print_and_save_result(args, benchmark_duration, len(data["latencies"]), model,
+    await print_and_save_result(args, benchmark_duration, len(data["latencies"]), model,
                           data["latencies"], data["ttfts"], data["itls"],
                           data["tpots"], data["errors"])
 
@@ -516,6 +546,7 @@ def save_json_results(args: argparse.Namespace, benchmark_result, server_metrics
     "num_prompts_attempted": benchmark_result['num_prompts_attempted'],
     "num_prompts_succeeded": benchmark_result['num_prompts_succeeded'],
     "request_rate": args.request_rate,
+    "queries_per_second": benchmark_result['queries_per_second'],
     'server_metrics': {
       **server_metrics
     },
@@ -742,6 +773,7 @@ def print_metrics(metrics: List[str], duration: float, namespace: str, job: str)
       logger.debug("HTTP Error: %s" % (response))
       continue
     server_metrics[metric] = metric_results
+  return server_metrics
 
 def get_stats_for_set(name, description, points):
@@ -765,7 +797,7 @@ def get_stats_for_set(name, description, points):
     f'p99_{name}': p99,
   }
 
-def print_and_save_result(args: argparse.Namespace, benchmark_duration, total_requests, model, request_latencies, ttfts, itls, tpots, errors):
+async def print_and_save_result(args: argparse.Namespace, benchmark_duration, total_requests, model, request_latencies, ttfts, itls, tpots, errors):
   benchmark_result = {}
 
   print(f"====Result for Model: {model}====")
@@ -773,6 +805,10 @@ def print_and_save_result(args: argparse.Namespace, benchmark_duration, total_re
   print(f"Total time: {benchmark_duration:.2f} s")
   print(f"Successful/total requests: {len(request_latencies)}/{total_requests}")
   print(f"Requests/min: {60 * total_requests / benchmark_duration:.2f}")
+  counter = await AsyncRequestCounter()
+  queries_per_second = await counter.get_qps()
+  print(f"Queries/sec: {queries_per_second:.2f}")
+  benchmark_result['queries_per_second'] = queries_per_second
   benchmark_result["num_prompts_attempted"] = total_requests
   benchmark_result["num_prompts_succeeded"] = len(request_latencies)
   benchmark_result['benchmark_time'] = benchmark_duration
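
Sketch for reviewers (not part of the patch): a minimal driver showing how the `AsyncRequestCounter` singleton added above is meant to be used — constructed once with the target request count, incremented per request, then queried for QPS once the final request lands. The `fake_request` helper, the target of 5 requests, and the import from `benchmark_serving` are illustrative assumptions, standing in for the real request path in the script.

```python
import asyncio

# Assumption: the patched module is importable; the class is the one added in this diff.
from benchmark_serving import AsyncRequestCounter


async def fake_request(delay_s: float) -> None:
  # Stand-in for a real benchmark request; only the counter interaction matters here.
  await asyncio.sleep(delay_s)
  counter = await AsyncRequestCounter()  # returns the existing singleton
  await counter.increment()


async def main() -> None:
  target = 5  # illustrative; the real script passes args.num_prompts
  # First construction records the start time and the target request count.
  await AsyncRequestCounter(target)
  await asyncio.gather(*(fake_request(0.1) for _ in range(target)))
  # _end_time is stamped when the target-th increment arrives, so QPS is readable now.
  counter = await AsyncRequestCounter()
  print(f"Queries/sec: {await counter.get_qps():.2f}")


if __name__ == "__main__":
  asyncio.run(main())
```

Because `__new__` is declared `async`, every construction must be awaited; later constructions return the same instance, which is why `print_and_save_result` can recover the QPS without the counter being threaded through the call chain. Note that `get_qps` reads `_end_time`, which is only set once the increment count reaches the configured target.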