+import threading
 import argparse
 from collections.abc import Iterable
 import json
@@ -67,8 +68,7 @@ def get_token_throughput_latencies(
     if not additional_sampling_params:
         additional_sampling_params = {}
 
-    clients = construct_clients(llm_api=llm_api, num_clients=num_concurrent_requests)
-    req_launcher = RequestsLauncher(clients)
+    completed_requests_lock = threading.Lock()
     completed_requests = []
     num_completed_requests = 0
     # make up prompts outside of send loop for faster benchmarking loop
@@ -87,65 +87,81 @@ def get_token_throughput_latencies(
             tokenizer=tokenizer
         ))
     start_time = time.monotonic()
-    iter = 0
     pbar = tqdm(total=max_num_completed_requests)
-    while (
-        time.monotonic() - start_time < test_timeout_s
-        and len(completed_requests) < max_num_completed_requests
-    ):
-        iter += 1
-
-        default_sampling_params = {"max_tokens": num_output_tokens_list.pop()}
-        default_sampling_params.update(additional_sampling_params)
-        request_config = RequestConfig(
-            model=model,
-            prompt=prompts.pop(),
-            sampling_params=default_sampling_params,
-            llm_api=llm_api,
-        )
-        req_launcher.launch_requests(request_config)
-        # Retrieving results less frequently allows for more concurrent requests
-        # to be launched. This will overall reduce the amount of time it takes
-        # for the test to run.
-        if not (iter % num_concurrent_requests):
+
+    def launch_request(thread_index):
+        nonlocal num_completed_requests
+        clients = construct_clients(llm_api=llm_api, num_clients=1)
+        req_launcher = RequestsLauncher(clients)
+        request_index = thread_index % max_num_completed_requests
+
+        while (
+            time.monotonic() - start_time < test_timeout_s
+            and num_completed_requests < max_num_completed_requests
+        ):
+
+            default_sampling_params = {"max_tokens": num_output_tokens_list[request_index]}
+            default_sampling_params.update(additional_sampling_params)
+            request_config = RequestConfig(
+                model=model,
+                prompt=prompts[request_index],
+                sampling_params=default_sampling_params,
+                llm_api=llm_api,
+            )
+            req_launcher.launch_requests(request_config)
+
             outs = req_launcher.get_next_ready()
             all_metrics = []
             for out in outs:
                 request_metrics, gen_text, _ = out
                 num_output_tokens = get_token_length(gen_text)
-                if num_output_tokens:
-                    request_metrics[common_metrics.INTER_TOKEN_LAT] /= num_output_tokens
-                else:
-                    request_metrics[common_metrics.INTER_TOKEN_LAT] = 0
-                request_metrics[common_metrics.NUM_OUTPUT_TOKENS] = num_output_tokens
-                request_metrics[common_metrics.NUM_TOTAL_TOKENS] = request_metrics[common_metrics.NUM_INPUT_TOKENS] + num_output_tokens
-                request_metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = num_output_tokens / request_metrics[common_metrics.E2E_LAT]
-                all_metrics.append(request_metrics)
-            completed_requests.extend(all_metrics)
-            pbar.update(len(completed_requests) - num_completed_requests)
-            num_completed_requests = len(completed_requests)
+                with completed_requests_lock:
+                    if num_completed_requests < max_num_completed_requests:
+                        if num_output_tokens:
+                            request_metrics[common_metrics.INTER_TOKEN_LAT] /= request_metrics[common_metrics.NUM_OUTPUT_TOKENS]
+                        else:
+                            request_metrics[common_metrics.INTER_TOKEN_LAT] = 0
+                        request_metrics[common_metrics.NUM_OUTPUT_TOKENS] = num_output_tokens
+                        request_metrics[common_metrics.NUM_TOTAL_TOKENS] = request_metrics[common_metrics.NUM_INPUT_TOKENS] + num_output_tokens
+                        request_metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = num_output_tokens / request_metrics[common_metrics.E2E_LAT]
+                        all_metrics.append(request_metrics)
+                        completed_requests.extend(all_metrics)
+                        pbar.update(len(all_metrics))
+                        num_completed_requests += len(all_metrics)
+            request_index = (request_index + num_concurrent_requests) % max_num_completed_requests
+
+    threads = []
+    for i in range(num_concurrent_requests):
+        thread = threading.Thread(target=launch_request, args=(i,))
+        threads.append(thread)
+        thread.start()
+
+    for thread in threads:
+        thread.join()
 
     pbar.close()
     end_time = time.monotonic()
    if end_time - start_time >= test_timeout_s:
         print("Test timed out before all requests could be completed.")
 
     # check one last time that there are no remaining results to collect.
+    clients = construct_clients(llm_api=llm_api, num_clients=1)
+    req_launcher = RequestsLauncher(clients)
     outs = req_launcher.get_next_ready()
     all_metrics = []
     for out in outs:
         request_metrics, gen_text, _ = out
         num_output_tokens = get_token_length(gen_text)
-        if num_output_tokens:
-            request_metrics[common_metrics.INTER_TOKEN_LAT] /= num_output_tokens
-        else:
-            request_metrics[common_metrics.INTER_TOKEN_LAT] = 0
-        request_metrics[common_metrics.NUM_OUTPUT_TOKENS] = num_output_tokens
-        request_metrics[common_metrics.NUM_TOTAL_TOKENS] = request_metrics[common_metrics.NUM_INPUT_TOKENS] + num_output_tokens
-        request_metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = num_output_tokens / request_metrics[common_metrics.E2E_LAT]
-
-        all_metrics.append(request_metrics)
-    completed_requests.extend(all_metrics)
+        with completed_requests_lock:
+            if num_completed_requests < max_num_completed_requests:
+                if num_output_tokens:
+                    request_metrics[common_metrics.INTER_TOKEN_LAT] /= num_output_tokens
+                else:
+                    request_metrics[common_metrics.INTER_TOKEN_LAT] = 0
+                request_metrics[common_metrics.NUM_OUTPUT_TOKENS] = num_output_tokens
+                request_metrics[common_metrics.NUM_TOTAL_TOKENS] = request_metrics[common_metrics.NUM_INPUT_TOKENS] + num_output_tokens
+                request_metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = num_output_tokens / request_metrics[common_metrics.E2E_LAT]
+                completed_requests.extend(request_metrics)
 
     print(f"\Results for token benchmark for {model} queried with the {llm_api} api.\n")
     ret = metrics_summary(completed_requests, start_time, end_time)
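
For context, the diff above replaces the single request loop (one launcher shared across num_concurrent_requests clients, polled every few iterations) with one worker thread per concurrent request, each owning its own client, while a threading.Lock guards the shared completed-request list and counter. The standalone sketch below illustrates that pattern only; fake_request(), the constants, and the metric dict are placeholders for illustration and are not part of the benchmark code.

import threading
import time
import random

test_timeout_s = 10
max_num_completed_requests = 20
num_concurrent_requests = 4

completed_requests = []          # shared across threads, guarded by the lock
num_completed_requests = 0
completed_requests_lock = threading.Lock()
start_time = time.monotonic()


def fake_request() -> dict:
    """Stand-in for launching one request and collecting its metrics."""
    latency = random.uniform(0.05, 0.2)
    time.sleep(latency)
    return {"e2e_latency_s": latency}


def worker(thread_index: int) -> None:
    global num_completed_requests
    # Each worker keeps issuing requests until the timeout or the global cap is hit.
    while (
        time.monotonic() - start_time < test_timeout_s
        and num_completed_requests < max_num_completed_requests
    ):
        metrics = fake_request()
        # Only the bookkeeping runs under the lock, and the cap is re-checked
        # inside it so the workers never record more than the target count.
        with completed_requests_lock:
            if num_completed_requests < max_num_completed_requests:
                completed_requests.append(metrics)
                num_completed_requests += 1


threads = [
    threading.Thread(target=worker, args=(i,))
    for i in range(num_concurrent_requests)
]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(f"completed {num_completed_requests} requests")

Keeping the request itself outside the lock and re-checking the cap inside it is what lets the workers run concurrently without overshooting max_num_completed_requests, which mirrors the structure of the change above.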