
Commit f1d6bed

fix: subsequent requests cannot be sent until 'num_concurrent_requests' requests have all finished in non-blocking mode (#59)

* fix: subsequent requests cannot be sent until 'num_concurrent_requests' requests have all finished in non-blocking mode
* chore: revert missing part

Signed-off-by: Sungjae Lee <[email protected]>
1 parent 03872a4 commit f1d6bed
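
The commit replaces the batched launch loop with one worker thread per concurrency slot, so each slot issues its next request as soon as its previous one returns, and a lock guards the shared completion state. Below is a minimal, standalone sketch of that pattern, not the benchmark code itself (which appears in the diff further down); the values and the time.sleep() stand-in for a real LLM request are hypothetical.

# Sketch of the thread-per-slot, non-blocking launch pattern this commit adopts.
# Hypothetical values; a real request is simulated with time.sleep().
import random
import threading
import time

num_concurrent_requests = 4
max_num_completed_requests = 20
test_timeout_s = 30

completed = []                        # shared results, guarded by a lock
completed_lock = threading.Lock()
num_completed = 0
start_time = time.monotonic()

def launch_request(thread_index: int) -> None:
    """Worker loop: keep issuing requests until the global budget is reached."""
    global num_completed
    while (
        time.monotonic() - start_time < test_timeout_s
        and num_completed < max_num_completed_requests
    ):
        latency = random.uniform(0.05, 0.2)   # stand-in for a real LLM call
        time.sleep(latency)
        with completed_lock:
            # Re-check under the lock so concurrent finishers cannot push
            # the count past the requested maximum.
            if num_completed < max_num_completed_requests:
                completed.append({"thread": thread_index, "latency_s": latency})
                num_completed += 1

threads = [
    threading.Thread(target=launch_request, args=(i,))
    for i in range(num_concurrent_requests)
]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(f"collected {len(completed)} results from {num_concurrent_requests} workers")

Unlike the old approach, where a new batch could only start after all 'num_concurrent_requests' in-flight requests had finished, each worker here refills its own slot immediately, which keeps the target concurrency level saturated for the whole run.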

File tree

1 file changed: +59 −43 lines changed

token_benchmark_ray.py

+59 −43
@@ -1,3 +1,4 @@
+import threading
 import argparse
 from collections.abc import Iterable
 import json
@@ -67,8 +68,7 @@ def get_token_throughput_latencies(
     if not additional_sampling_params:
         additional_sampling_params = {}

-    clients = construct_clients(llm_api=llm_api, num_clients=num_concurrent_requests)
-    req_launcher = RequestsLauncher(clients)
+    completed_requests_lock = threading.Lock()
    completed_requests = []
    num_completed_requests = 0
    # make up prompts outside of send loop for faster benchmarking loop
@@ -87,65 +87,81 @@ def get_token_throughput_latencies(
             tokenizer=tokenizer
         ))
     start_time = time.monotonic()
-    iter = 0
     pbar = tqdm(total=max_num_completed_requests)
-    while (
-        time.monotonic() - start_time < test_timeout_s
-        and len(completed_requests) < max_num_completed_requests
-    ):
-        iter += 1
-
-        default_sampling_params = {"max_tokens": num_output_tokens_list.pop()}
-        default_sampling_params.update(additional_sampling_params)
-        request_config = RequestConfig(
-            model=model,
-            prompt=prompts.pop(),
-            sampling_params=default_sampling_params,
-            llm_api=llm_api,
-        )
-        req_launcher.launch_requests(request_config)
-        # Retrieving results less frequently allows for more concurrent requests
-        # to be launched. This will overall reduce the amount of time it takes
-        # for the test to run.
-        if not (iter % num_concurrent_requests):
+
+    def launch_request(thread_index):
+        nonlocal num_completed_requests
+        clients = construct_clients(llm_api=llm_api, num_clients=1)
+        req_launcher = RequestsLauncher(clients)
+        request_index = thread_index % max_num_completed_requests
+
+        while (
+            time.monotonic() - start_time < test_timeout_s
+            and num_completed_requests < max_num_completed_requests
+        ):
+
+            default_sampling_params = {"max_tokens": num_output_tokens_list[request_index] }
+            default_sampling_params.update(additional_sampling_params)
+            request_config = RequestConfig(
+                model=model,
+                prompt=prompts[request_index],
+                sampling_params=default_sampling_params,
+                llm_api=llm_api,
+            )
+            req_launcher.launch_requests(request_config)
+
             outs = req_launcher.get_next_ready()
             all_metrics = []
             for out in outs:
                 request_metrics, gen_text, _ = out
                 num_output_tokens = get_token_length(gen_text)
-                if num_output_tokens:
-                    request_metrics[common_metrics.INTER_TOKEN_LAT] /= num_output_tokens
-                else:
-                    request_metrics[common_metrics.INTER_TOKEN_LAT] = 0
-                request_metrics[common_metrics.NUM_OUTPUT_TOKENS] = num_output_tokens
-                request_metrics[common_metrics.NUM_TOTAL_TOKENS] = request_metrics[common_metrics.NUM_INPUT_TOKENS] + num_output_tokens
-                request_metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = num_output_tokens / request_metrics[common_metrics.E2E_LAT]
-                all_metrics.append(request_metrics)
-            completed_requests.extend(all_metrics)
-            pbar.update(len(completed_requests) - num_completed_requests)
-            num_completed_requests = len(completed_requests)
+                with completed_requests_lock:
+                    if num_completed_requests < max_num_completed_requests:
+                        if num_output_tokens:
+                            request_metrics[common_metrics.INTER_TOKEN_LAT] /= request_metrics[common_metrics.NUM_OUTPUT_TOKENS]
+                        else:
+                            request_metrics[common_metrics.INTER_TOKEN_LAT] = 0
+                        request_metrics[common_metrics.NUM_OUTPUT_TOKENS] = num_output_tokens
+                        request_metrics[common_metrics.NUM_TOTAL_TOKENS] = request_metrics[common_metrics.NUM_INPUT_TOKENS] + num_output_tokens
+                        request_metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = num_output_tokens / request_metrics[common_metrics.E2E_LAT]
+                        all_metrics.append(request_metrics)
+                        completed_requests.extend(all_metrics)
+                        pbar.update(len(all_metrics))
+                        num_completed_requests += len(all_metrics)
+            request_index = (request_index + num_concurrent_requests) % max_num_completed_requests
+
+    threads = []
+    for i in range(num_concurrent_requests):
+        thread = threading.Thread(target=launch_request, args=(i,))
+        threads.append(thread)
+        thread.start()
+
+    for thread in threads:
+        thread.join()

     pbar.close()
     end_time = time.monotonic()
     if end_time - start_time >= test_timeout_s:
         print("Test timed out before all requests could be completed.")

     # check one last time that there are no remaining results to collect.
+    clients = construct_clients(llm_api=llm_api, num_clients=1)
+    req_launcher = RequestsLauncher(clients)
     outs = req_launcher.get_next_ready()
     all_metrics = []
     for out in outs:
         request_metrics, gen_text, _ = out
         num_output_tokens = get_token_length(gen_text)
-        if num_output_tokens:
-            request_metrics[common_metrics.INTER_TOKEN_LAT] /= num_output_tokens
-        else:
-            request_metrics[common_metrics.INTER_TOKEN_LAT] = 0
-        request_metrics[common_metrics.NUM_OUTPUT_TOKENS] = num_output_tokens
-        request_metrics[common_metrics.NUM_TOTAL_TOKENS] = request_metrics[common_metrics.NUM_INPUT_TOKENS] + num_output_tokens
-        request_metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = num_output_tokens / request_metrics[common_metrics.E2E_LAT]
-
-        all_metrics.append(request_metrics)
-    completed_requests.extend(all_metrics)
+        with completed_requests_lock:
+            if num_completed_requests < max_num_completed_requests:
+                if num_output_tokens:
+                    request_metrics[common_metrics.INTER_TOKEN_LAT] /= num_output_tokens
+                else:
+                    request_metrics[common_metrics.INTER_TOKEN_LAT] = 0
+                request_metrics[common_metrics.NUM_OUTPUT_TOKENS] = num_output_tokens
+                request_metrics[common_metrics.NUM_TOTAL_TOKENS] = request_metrics[common_metrics.NUM_INPUT_TOKENS] + num_output_tokens
+                request_metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = num_output_tokens / request_metrics[common_metrics.E2E_LAT]
+                completed_requests.extend(request_metrics)

     print(f"\Results for token benchmark for {model} queried with the {llm_api} api.\n")
     ret = metrics_summary(completed_requests, start_time, end_time)

0 commit comments