import requests
import time
from typing import AsyncGenerator, List, Optional, Tuple, Dict
- from prometheus_client import start_http_server, Histogram, Gauge
+ from prometheus_client import start_http_server, Histogram, Gauge, Counter
import logging

import google.auth
tpot_metric = Histogram('LatencyProfileGenerator:time_per_output_token_ms', 'Time per output token per request (excluding first token) (ms)', buckets=[2**i for i in range(1, 16)])
ttft_metric = Histogram('LatencyProfileGenerator:time_to_first_token_ms', 'Time to first token per request (ms)', buckets=[2**i for i in range(1, 16)])
active_requests_metric = Gauge('LatencyProfileGenerator:active_requests', 'How many requests actively being processed')
+ total_request_count = Counter('LatencyProfileGenerator:request_count', 'How many total requests have been sent')
+
+ # Singleton class to track requests for QPS counting and calculation.
+ class AsyncRequestCounter:
+   _instance = None
+   _lock = asyncio.Lock()
+
+   async def __new__(cls, target_requests=None, *args, **kwargs):
+     async with cls._lock:
+       if not cls._instance:
+         cls._instance = super().__new__(cls)
+         cls._instance._count = 0
+         cls._instance._start_time = time.time()
+         cls._instance._target_requests = target_requests
+       return cls._instance
+
+   async def increment(self):
+     async with self._lock:
+       self._count += 1
+       if self._count == self._target_requests:
+         self._end_time = time.time()
+
+   async def get_qps(self):
+     return self._count / (self._end_time - self._start_time)
+

# Add trace config for monitoring in flight requests
async def on_request_start(session, trace_config_ctx, params):
  active_requests_metric.inc()
+   total_request_count.inc()
+   counter = await AsyncRequestCounter()
+   await counter.increment()

async def on_request_end(session, trace_config_ctx, params):
  active_requests_metric.dec()
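As a reference for reviewers, here is a minimal standalone sketch of how the singleton counter added above behaves. It assumes the AsyncRequestCounter class from this diff is in scope; the send_one stub, the target of 10 requests, and the 0.01 s sleep are illustrative only, not part of the benchmark.

    import asyncio

    async def send_one(counter):
      # Stand-in for one instrumented request; only the counting matters here.
      await asyncio.sleep(0.01)
      await counter.increment()

    async def main():
      # The first await of AsyncRequestCounter(...) creates the singleton,
      # records the start time, and stores the target request count.
      counter = await AsyncRequestCounter(target_requests=10)
      await asyncio.gather(*(send_one(counter) for _ in range(10)))
      # _end_time is only recorded when the count reaches the target, so
      # get_qps() is meaningful only after all target requests have completed.
      print(f"QPS: {await counter.get_qps():.2f}")

    asyncio.run(main())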
@@ -460,6 +488,8 @@ async def benchmark(
  model_weights = list(models_dict.values())

  benchmark_start_time_sec = time.time()
+   # Initialize the counter with target prompts
+   await AsyncRequestCounter(args.num_prompts)
  tasks: List[asyncio.Task] = []
  prompts_sent = 0
  async for request in generate_next_request(input_requests, args.request_rate):
@@ -501,12 +531,12 @@ async def benchmark(

  benchmark_duration_sec = time.time() - benchmark_start_time_sec

-   print_and_save_result(args, benchmark_duration_sec, prompts_sent, "weighted",
+   await print_and_save_result(args, benchmark_duration_sec, prompts_sent, "weighted",
                        overall_results["latencies"], overall_results["ttfts"],
                        overall_results["itls"], overall_results["tpots"],
                        overall_results["errors"])
  for model, data in per_model_results.items():
-     print_and_save_result(args, benchmark_duration_sec, len(data["latencies"]), model,
+     await print_and_save_result(args, benchmark_duration_sec, len(data["latencies"]), model,
                          data["latencies"], data["ttfts"], data["itls"],
                          data["tpots"], data["errors"])

@@ -522,6 +552,7 @@ def save_json_results(args: argparse.Namespace, benchmark_result, server_metrics
    "num_prompts_attempted": benchmark_result['num_prompts_attempted'],
    "num_prompts_succeeded": benchmark_result['num_prompts_succeeded'],
    "request_rate": args.request_rate,
+     "queries_per_second": benchmark_result['queries_per_second'],
    'server_metrics': {
      **server_metrics
    },
@@ -760,6 +791,7 @@ def print_metrics(metrics: List[str], duration_sec: float, namespace: str, job:
      logger.debug("HTTP Error: %s" % (response))
      continue
    server_metrics[metric] = metric_results
+
  return server_metrics

def get_stats_for_set(name, description, points):
@@ -783,14 +815,18 @@ def get_stats_for_set(name, description, points):
    f'p99_{name}': p99,
  }

- def print_and_save_result(args: argparse.Namespace, benchmark_duration_sec, total_requests, model, request_latencies, ttfts, itls, tpots, errors):
+ async def print_and_save_result(args: argparse.Namespace, benchmark_duration_sec, total_requests, model, request_latencies, ttfts, itls, tpots, errors):
  benchmark_result = {}

  print(f"====Result for Model: {model}====")
  print(f"Errors: {errors}")
  print(f"Total time (seconds): {benchmark_duration_sec:.2f}s")
  print(f"Successful/total requests: {len(request_latencies)}/{total_requests}")
  print(f"Requests/sec: {total_requests / benchmark_duration_sec:.2f}")
+   counter = await AsyncRequestCounter()
+   queries_per_second = await counter.get_qps()
+   print(f"Queries/sec: {queries_per_second:.2f}")
+   benchmark_result['queries_per_second'] = queries_per_second
  benchmark_result["num_prompts_attempted"] = total_requests
  benchmark_result["num_prompts_succeeded"] = len(request_latencies)
  benchmark_result['benchmark_time'] = benchmark_duration_sec