@@ -28,7 +28,7 @@
import requests
import time
from typing import AsyncGenerator, List, Optional, Tuple, Dict
-from prometheus_client import start_http_server, Histogram, Gauge
+from prometheus_client import start_http_server, Histogram, Gauge, Counter
import logging

import google.auth
@@ -53,10 +53,38 @@
tpot_metric = Histogram('LatencyProfileGenerator:time_per_output_token_ms', 'Time per output token per request (excluding first token) (ms)', buckets=[2**i for i in range(1, 16)])
ttft_metric = Histogram('LatencyProfileGenerator:time_to_first_token_ms', 'Time to first token per request (ms)', buckets=[2**i for i in range(1, 16)])
active_requests_metric = Gauge('LatencyProfileGenerator:active_requests', 'How many requests actively being processed')
+total_request_count = Counter('LatencyProfileGenerator:request_count', 'How many total requests have been sent')
+
+# Singleton class to track requests for QPS counting and calculation.
+class AsyncRequestCounter:
+  _instance = None
+  _lock = asyncio.Lock()
+
+  async def __new__(cls, target_requests=None, *args, **kwargs):
+    async with cls._lock:
+      if not cls._instance:
+        cls._instance = super().__new__(cls)
+        cls._instance._count = 0
+        cls._instance._start_time = time.time()
+        cls._instance._target_requests = target_requests
+      return cls._instance
+
+  async def increment(self):
+    async with self._lock:
+      self._count += 1
+      if self._count == self._target_requests:
+        self._end_time = time.time()
+
+  async def get_qps(self):
+    return self._count / (self._end_time - self._start_time)
+

# Add trace config for monitoring in flight requests
async def on_request_start(session, trace_config_ctx, params):
  active_requests_metric.inc()
+  total_request_count.inc()
+  counter = await AsyncRequestCounter()
+  await counter.increment()

async def on_request_end(session, trace_config_ctx, params):
  active_requests_metric.dec()
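The hunk above relies on two pieces the diff only implies: AsyncRequestCounter is an awaitable singleton (every `await AsyncRequestCounter()` resolves to the same instance), and on_request_start/on_request_end match aiohttp's TraceConfig callback signature. A minimal sketch of how they could fit together, assuming aiohttp is the HTTP client and using a hypothetical localhost endpoint (not part of this change):

import asyncio
import aiohttp

async def demo():
  # First await creates the singleton and records the start time and the
  # target request count (1 here, so the single request below also records
  # the end time).
  await AsyncRequestCounter(target_requests=1)

  # Assumed wiring (not shown in this diff): register the trace hooks above
  # on the session that sends benchmark requests.
  trace_config = aiohttp.TraceConfig()
  trace_config.on_request_start.append(on_request_start)
  trace_config.on_request_end.append(on_request_end)

  async with aiohttp.ClientSession(trace_configs=[trace_config]) as session:
    # Hypothetical endpoint, purely for illustration.
    async with session.get("http://localhost:8000/health") as resp:
      await resp.text()

  # Later awaits return the same instance; QPS = count / (end - start).
  counter = await AsyncRequestCounter()
  qps = await counter.get_qps()
  print(f"measured QPS: {qps:.2f}")

asyncio.run(demo())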
@@ -460,6 +488,8 @@ async def benchmark(
  model_weights = list(models_dict.values())

  benchmark_start_time_sec = time.time()
+  # Initialize the counter with target prompts
+  await AsyncRequestCounter(args.num_prompts)
  tasks: List[asyncio.Task] = []
  prompts_sent = 0
  async for request in generate_next_request(input_requests, args.request_rate):
@@ -501,12 +531,12 @@ async def benchmark(

  benchmark_duration_sec = time.time() - benchmark_start_time_sec

-  print_and_save_result(args, benchmark_duration_sec, prompts_sent, "weighted",
+  await print_and_save_result(args, benchmark_duration_sec, prompts_sent, "weighted",
                         overall_results["latencies"], overall_results["ttfts"],
                         overall_results["itls"], overall_results["tpots"],
                         overall_results["errors"])
  for model, data in per_model_results.items():
-    print_and_save_result(args, benchmark_duration_sec, len(data["latencies"]), model,
+    await print_and_save_result(args, benchmark_duration_sec, len(data["latencies"]), model,
                           data["latencies"], data["ttfts"], data["itls"],
                           data["tpots"], data["errors"])

@@ -522,6 +552,7 @@ def save_json_results(args: argparse.Namespace, benchmark_result, server_metrics
    "num_prompts_attempted": benchmark_result['num_prompts_attempted'],
    "num_prompts_succeeded": benchmark_result['num_prompts_succeeded'],
    "request_rate": args.request_rate,
+    "queries_per_second": benchmark_result['queries_per_second'],
    'server_metrics': {
      **server_metrics
    },
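For reference, a sketch of the record this function writes after the change, with made-up values and other fields omitted; queries_per_second is the send rate measured by AsyncRequestCounter, while request_rate is the configured target rate:

# Illustrative shape only; values are hypothetical and most fields are omitted.
example_record = {
  "num_prompts_attempted": 1000,
  "num_prompts_succeeded": 998,
  "request_rate": 10.0,
  "queries_per_second": 9.87,
  "server_metrics": {},
}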
@@ -760,6 +791,7 @@ def print_metrics(metrics: List[str], duration_sec: float, namespace: str, job:
      logger.debug("HTTP Error: %s" % (response))
      continue
    server_metrics[metric] = metric_results
+
  return server_metrics

def get_stats_for_set(name, description, points):
@@ -783,14 +815,18 @@ def get_stats_for_set(name, description, points):
    f'p99_{name}': p99,
  }

-def print_and_save_result(args: argparse.Namespace, benchmark_duration_sec, total_requests, model, request_latencies, ttfts, itls, tpots, errors):
+async def print_and_save_result(args: argparse.Namespace, benchmark_duration_sec, total_requests, model, request_latencies, ttfts, itls, tpots, errors):
  benchmark_result = {}

  print(f"====Result for Model: {model}====")
  print(f"Errors: {errors}")
  print(f"Total time (seconds): {benchmark_duration_sec:.2f}s")
  print(f"Successful/total requests: {len(request_latencies)}/{total_requests}")
  print(f"Requests/sec: {total_requests / benchmark_duration_sec:.2f}")
+  counter = await AsyncRequestCounter()
+  queries_per_second = await counter.get_qps()
+  print(f"Queries/sec: {queries_per_second:.2f}")
+  benchmark_result['queries_per_second'] = queries_per_second
  benchmark_result["num_prompts_attempted"] = total_requests
  benchmark_result["num_prompts_succeeded"] = len(request_latencies)
  benchmark_result['benchmark_time'] = benchmark_duration_sec
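The function now prints two throughput figures: Requests/sec divides attempted requests by the full benchmark wall time (including waiting for the last responses), while Queries/sec comes from the counter and measures how fast requests were dispatched, from the counter's creation until the target count was reached. A toy comparison with hypothetical numbers:

# Hypothetical numbers, only to contrast the two rates printed above.
total_requests = 1000
benchmark_duration_sec = 100.0   # includes time spent waiting for responses
dispatch_window_sec = 95.0       # counter start until the 1000th request was sent

requests_per_sec = total_requests / benchmark_duration_sec   # 10.0  -> "Requests/sec"
queries_per_sec = total_requests / dispatch_window_sec       # ~10.5 -> "Queries/sec"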