
Commit fd39679

Qps observability (#32)
* Add prometheus metric for request count
* Add singleton request counter

Adds singleton counter that allows for calculating QPS.
1 parent 8e4a7a0 commit fd39679

File tree

1 file changed: +40 -4 lines changed

benchmark_serving.py

Lines changed: 40 additions & 4 deletions
@@ -28,7 +28,7 @@
 import requests
 import time
 from typing import AsyncGenerator, List, Optional, Tuple, Dict
-from prometheus_client import start_http_server, Histogram, Gauge
+from prometheus_client import start_http_server, Histogram, Gauge, Counter
 import logging
 
 import google.auth
@@ -53,10 +53,38 @@
 tpot_metric = Histogram('LatencyProfileGenerator:time_per_output_token_ms', 'Time per output token per request (excluding first token) (ms)', buckets=[2**i for i in range(1, 16)])
 ttft_metric = Histogram('LatencyProfileGenerator:time_to_first_token_ms', 'Time to first token per request (ms)', buckets=[2**i for i in range(1, 16)])
 active_requests_metric = Gauge('LatencyProfileGenerator:active_requests', 'How many requests actively being processed')
+total_request_count = Counter('LatencyProfileGenerator:request_count', 'How many total requests have been sent')
+
+# Singleton class to track requests for QPS counting and calculation.
+class AsyncRequestCounter:
+  _instance = None
+  _lock = asyncio.Lock()
+
+  async def __new__(cls, target_requests=None, *args, **kwargs):
+    async with cls._lock:
+      if not cls._instance:
+        cls._instance = super().__new__(cls)
+        cls._instance._count = 0
+        cls._instance._start_time = time.time()
+        cls._instance._target_requests = target_requests
+      return cls._instance
+
+  async def increment(self):
+    async with self._lock:
+      self._count += 1
+      if self._count == self._target_requests:
+        self._end_time = time.time()
+
+  async def get_qps(self):
+    return self._count / (self._end_time - self._start_time)
+
 
 # Add trace config for monitoring in flight requests
 async def on_request_start(session, trace_config_ctx, params):
   active_requests_metric.inc()
+  total_request_count.inc()
+  counter = await AsyncRequestCounter()
+  await counter.increment()
 
 async def on_request_end(session, trace_config_ctx, params):
   active_requests_metric.dec()
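As a rough illustration of how the singleton behaves, here is a minimal sketch, assuming the AsyncRequestCounter class from the hunk above is already defined in the same module or session; the target of 2 requests and the 0.5 s sleep are invented just to make the demo self-terminating:

import asyncio

async def demo():
  # First call creates the singleton, starts its clock, and stores the target.
  counter = await AsyncRequestCounter(target_requests=2)

  await counter.increment()
  await asyncio.sleep(0.5)   # space the two "requests" out so the measured window is non-zero
  await counter.increment()  # reaching the target records _end_time

  # Any later "construction" returns the same instance instead of a new one.
  assert (await AsyncRequestCounter()) is counter

  qps = await counter.get_qps()
  print(f"Queries/sec: {qps:.2f}")  # roughly 2 / 0.5 = 4

asyncio.run(demo())

Note that on_request_start awaits AsyncRequestCounter() with no arguments, which only yields a counter with a meaningful target because benchmark() (next hunk) constructs the singleton with args.num_prompts before any request is sent.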
@@ -460,6 +488,8 @@ async def benchmark(
   model_weights = list(models_dict.values())
 
   benchmark_start_time_sec = time.time()
+  # Initialize the counter with target prompts
+  await AsyncRequestCounter(args.num_prompts)
   tasks: List[asyncio.Task] = []
   prompts_sent = 0
   async for request in generate_next_request(input_requests, args.request_rate):
@@ -501,12 +531,12 @@ async def benchmark(
 
   benchmark_duration_sec = time.time() - benchmark_start_time_sec
 
-  print_and_save_result(args, benchmark_duration_sec, prompts_sent, "weighted",
+  await print_and_save_result(args, benchmark_duration_sec, prompts_sent, "weighted",
                         overall_results["latencies"], overall_results["ttfts"],
                         overall_results["itls"], overall_results["tpots"],
                         overall_results["errors"])
   for model, data in per_model_results.items():
-    print_and_save_result(args, benchmark_duration_sec, len(data["latencies"]), model,
+    await print_and_save_result(args, benchmark_duration_sec, len(data["latencies"]), model,
                           data["latencies"], data["ttfts"], data["itls"],
                           data["tpots"], data["errors"])
 
@@ -522,6 +552,7 @@ def save_json_results(args: argparse.Namespace, benchmark_result, server_metrics
     "num_prompts_attempted": benchmark_result['num_prompts_attempted'],
     "num_prompts_succeeded": benchmark_result['num_prompts_succeeded'],
     "request_rate": args.request_rate,
+    "queries_per_second": benchmark_result['queries_per_second'],
     'server_metrics': {
       **server_metrics
     },
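Downstream consumers can then read the new field back out of the saved results. A small sketch, assuming the keys shown in this hunk land at the top level of the written JSON and using a placeholder file name (the script's real output path depends on its arguments):

import json

# "benchmark_results.json" is a made-up path, not the script's actual output name.
with open("benchmark_results.json") as f:
  results = json.load(f)

print("request_rate:", results["request_rate"])
print("queries_per_second:", results["queries_per_second"])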
@@ -760,6 +791,7 @@ def print_metrics(metrics: List[str], duration_sec: float, namespace: str, job:
       logger.debug("HTTP Error: %s" % (response))
       continue
     server_metrics[metric] = metric_results
+
   return server_metrics
 
 def get_stats_for_set(name, description, points):
@@ -783,14 +815,18 @@ def get_stats_for_set(name, description, points):
     f'p99_{name}': p99,
   }
 
-def print_and_save_result(args: argparse.Namespace, benchmark_duration_sec, total_requests, model, request_latencies, ttfts, itls, tpots, errors):
+async def print_and_save_result(args: argparse.Namespace, benchmark_duration_sec, total_requests, model, request_latencies, ttfts, itls, tpots, errors):
   benchmark_result = {}
 
   print(f"====Result for Model: {model}====")
   print(f"Errors: {errors}")
   print(f"Total time (seconds): {benchmark_duration_sec:.2f} s")
   print(f"Successful/total requests: {len(request_latencies)}/{total_requests}")
   print(f"Requests/sec: {total_requests / benchmark_duration_sec:.2f}")
+  counter = await AsyncRequestCounter()
+  queries_per_second = await counter.get_qps()
+  print(f"Queries/sec: {queries_per_second:.2f}")
+  benchmark_result['queries_per_second'] = queries_per_second
   benchmark_result["num_prompts_attempted"] = total_requests
   benchmark_result["num_prompts_succeeded"] = len(request_latencies)
   benchmark_result['benchmark_time'] = benchmark_duration_sec
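Worth noting: the script now prints two rates. "Requests/sec" divides the attempted request count by the full benchmark wall time, while "Queries/sec" comes from the counter, whose clock starts when benchmark() creates the singleton and stops once the target number of requests has been counted in on_request_start, so the two can diverge if the run keeps waiting on responses after the last request is issued. A toy comparison, with all numbers invented purely for illustration:

num_prompts = 100
send_window_sec = 20.0         # counter window: singleton creation until the target count is reached
benchmark_duration_sec = 25.0  # full wall time, including waiting on the final responses

print(f"Requests/sec: {num_prompts / benchmark_duration_sec:.2f}")  # 4.00
print(f"Queries/sec: {num_prompts / send_window_sec:.2f}")          # 5.00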
