@@ -71,6 +71,21 @@ def get_token_throughput_latencies(
     req_launcher = RequestsLauncher(clients)
     completed_requests = []
     num_completed_requests = 0
+    # make up prompts outside of send loop for faster benchmarking loop
+    num_output_tokens_list = []
+    prompts = []
+    for i in range(max_num_completed_requests):
+        num_output_tokens = (sample_random_positive_int(
+            mean_output_tokens, stddev_output_tokens
+        ))
+        num_output_tokens_list.append(num_output_tokens)
+
+        prompts.append(randomly_sample_sonnet_lines_prompt(
+            prompt_tokens_mean=mean_input_tokens,
+            prompt_tokens_stddev=stddev_input_tokens,
+            expect_output_tokens=num_output_tokens,
+            tokenizer=tokenizer
+        ))
     start_time = time.monotonic()
     iter = 0
     pbar = tqdm(total=max_num_completed_requests)
@@ -79,21 +94,12 @@ def get_token_throughput_latencies(
         and len(completed_requests) < max_num_completed_requests
     ):
         iter += 1
-        num_output_tokens = sample_random_positive_int(
-            mean_output_tokens, stddev_output_tokens
-        )
-
-        prompt = randomly_sample_sonnet_lines_prompt(
-            prompt_tokens_mean=mean_input_tokens,
-            prompt_tokens_stddev=stddev_input_tokens,
-            expect_output_tokens=num_output_tokens,
-        )
 
-        default_sampling_params = {"max_tokens": num_output_tokens}
+        default_sampling_params = {"max_tokens": num_output_tokens_list.pop()}
         default_sampling_params.update(additional_sampling_params)
         request_config = RequestConfig(
             model=model,
-            prompt=prompt,
+            prompt=prompts.pop(),
             sampling_params=default_sampling_params,
             llm_api=llm_api,
         )
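
A minimal, hypothetical sketch of the pattern this diff applies, outside the llmperf code itself: sample all per-request inputs before the timed loop starts, then only pop and send inside it, so input generation does not inflate the measured benchmarking loop. The names run_benchmark and send_fn below are illustrative placeholders, not part of the PR.

# Sketch only: pre-generate request inputs, then pop them in the timed loop.
import random
import time

def run_benchmark(num_requests, send_fn):
    # Pre-generate per-request prompts and output-token budgets up front.
    prompts = [f"prompt {i}" for i in range(num_requests)]
    max_tokens_list = [random.randint(50, 150) for _ in range(num_requests)]

    start = time.monotonic()
    while prompts:
        # The timed loop only pops precomputed values and issues the request.
        send_fn(prompts.pop(), {"max_tokens": max_tokens_list.pop()})
    return time.monotonic() - start

# Example usage with a stand-in sender:
if __name__ == "__main__":
    elapsed = run_benchmark(10, lambda prompt, params: None)
    print(f"sent 10 requests in {elapsed:.4f}s")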