@@ -2,9 +2,11 @@
 
 import argparse
 import json
+import os
+import random
 import subprocess
 from time import sleep, time
-from typing import Optional
+from typing import Optional, Union
 
 import datasets
 import logging
@@ -18,46 +20,54 @@
 logger = logging.getLogger("server-bench")
 
 
-def get_prompts(n_prompts: int) -> list[str]:
-    logger.info("Loading MMLU dataset...")
-    ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
+def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
+    ret = []
+    if dataset_name.lower() == "mmlu":
+        logger.info("Loading MMLU dataset...")
+        ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
+    else:
+        return None
     if n_prompts >= 0:
         ret = ret[:n_prompts]
     return ret
 
 
-def get_server(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int) -> dict:
+def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int) -> list[int]:
+    assert n_prompts >= 0
+    ret: list[int] = []
+    for i in range(n_prompts):
+        random.seed(13 * i + 0)
+        ret.append(random.randint(prompt_length_min, prompt_length_max))
+    return ret
+
+
+def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
+    return [[random.randint(100, 10000) for _ in range(pl)] for pl in prompt_lengths]
+
+
+def get_server(path_server: str, path_log: Optional[str]) -> dict:
     logger.info("Starting the llama.cpp server...")
-    address = f"http://localhost:{port}"
-
-    popen_args: list[str] = [
-        path_server,
-        "--flash-attn",
-        "--n-gpu-layers", str(n_gpu_layers),
-        "--parallel", str(parallel),
-        "--ctx-size", str(parallel * ctx_size),
-        "--model", path_model,
-        "--port", str(port),
-        "--swa-full",  # FIXME performance bad otherwise
-        # "--attn-streams",
-    ]
-    fout = open("bench.log", "w") if path_log is not None else subprocess.DEVNULL
-    process = subprocess.Popen(popen_args, stdout=fout, stderr=subprocess.STDOUT)
+    hostname: str = os.environ.get("LLAMA_ARG_HOST", "127.0.0.1")
+    port: str = os.environ.get("LLAMA_ARG_PORT", "8080")
+    address: str = f"http://{hostname}:{port}"
+
+    fout = open(path_log, "w") if path_log is not None else subprocess.DEVNULL
+    process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)
 
     n_failures: int = 0
     while True:
         try:
            sleep(1.0)
            exit_code = process.poll()
            if exit_code is not None:
-                raise RuntimeError(f"llama.cpp server for {path_model} exited unexpectedly with exit code {exit_code}")
+                raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}, see {path_log}")
            response = requests.get(f"{address}/health")
            if response.status_code == 200:
                break
        except requests.ConnectionError:
            n_failures += 1
            if n_failures >= 10:
-                raise RuntimeError(f"llama.cpp server for {path_model} is not healthy after 10 seconds")
+                raise RuntimeError("llama.cpp server is not healthy after 10 seconds")
 
     return {"process": process, "address": address, "fout": fout}
 
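The two RNG helpers added above make the synthetic workload reproducible: each prompt length is drawn from its own fixed seed (13 * i), so repeated runs exercise the same prompt mix. A minimal standalone sketch of how they behave (helper bodies copied from the patch, the example arguments are made up):

import random

def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int) -> list[int]:
    # One fixed seed per prompt index -> identical lengths on every run.
    ret: list[int] = []
    for i in range(n_prompts):
        random.seed(13 * i + 0)
        ret.append(random.randint(prompt_length_min, prompt_length_max))
    return ret

def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
    # Synthetic prompts are raw token IDs in [100, 10000] rather than text.
    return [[random.randint(100, 10000) for _ in range(pl)] for pl in prompt_lengths]

lengths = get_prompt_lengths_rng(3, 1024, 2048)   # deterministic across runs
prompts = get_prompts_rng(lengths)
print(lengths == [len(p) for p in prompts])       # True: one token ID per requested position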
@@ -87,76 +97,116 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
     session = data["session"]
     server_address: str = data["server_address"]
 
-    response = session.post(
-        f"{server_address}/apply-template",
-        json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
-    )
-    if response.status_code != 200:
-        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
-    prompt: str = json.loads(response.text)["prompt"]
-
-    json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
-    response = session.post(f"{server_address}/completion", json=json_data, stream=True)
+    t_submit = time()
+    if data["synthetic_prompt"]:
+        json_data: dict = {
+            "prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
+            "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
+        response = session.post(f"{server_address}/completion", json=json_data, stream=True)
+    else:
+        response = session.post(
+            f"{server_address}/apply-template",
+            json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
+        )
+        if response.status_code != 200:
+            raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+        prompt: str = json.loads(response.text)["prompt"]
+
+        json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
+        response = session.post(f"{server_address}/completion", json=json_data, stream=True)
 
-    last_valid_line: str = ""
     token_arrival_times: list[float] = []
-    for line in response.iter_lines(decode_unicode=True):
-        if not line.startswith("data: "):
+    for line in response.iter_lines(decode_unicode=False):
+        if not line.startswith(b"data: "):
            continue
-        last_valid_line = line
        token_arrival_times.append(time())
     token_arrival_times = token_arrival_times[:-1]
 
     if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
-    timings: dict = json.loads(last_valid_line[6:])["timings"]
 
-    return (timings["prompt_ms"], token_arrival_times)
-
-
-def benchmark(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int, n_prompts: int, n_predict: int):
-    num_workers: int = parallel + 1
-    prompts: list[str] = get_prompts(n_prompts)
+    return (t_submit, token_arrival_times)
+
+
+def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int):
+    if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
+        logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
+        os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
+    if os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
+        logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
+        os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
+    if os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
+        logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
+        os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"
+
+    parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL", 1))
+    prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts)
+    synthetic_prompts: bool = prompts is None
+    prompt_n = []
+
+    if synthetic_prompts:
+        prompt_source_split: list[str] = prompt_source.split("-")
+        assert len(prompt_source_split) == 3
+        assert prompt_source_split[0].lower() == "rng"
+        prompt_length_min: int = int(prompt_source_split[1])
+        prompt_length_max: int = int(prompt_source_split[2])
+        logger.info("Generating random prompts...")
+        prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max)
+        prompts = get_prompts_rng(prompt_n)
+    else:
+        n_predict_min = n_predict
+
+    if os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
+        context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
+        context_total: int = context_per_slot * parallel
+        os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
+        logger.info(f"LLAMA_ARG_CTX_SIZE not explicitly set, using {context_total} ({context_per_slot} per slot).")
 
     server: Optional[dict] = None
     session = None
     try:
-        server = get_server(path_server, path_model, path_log, port, n_gpu_layers, parallel, ctx_size)
+        server = get_server(path_server, path_log)
        server_address: str = server["address"]
 
-        adapter = requests.adapters.HTTPAdapter(pool_connections=num_workers, pool_maxsize=num_workers)  # type: ignore
+        adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
        session = requests.Session()
        session.mount("http://", adapter)
        session.mount("https://", adapter)
 
        data: list[dict] = []
+
        for i, p in enumerate(prompts):
-            data.append({"session": session, "server_address": server_address, "prompt": p, "n_predict": n_predict, "seed": i})
+            random.seed(13 * i + 1)
+            data.append({
+                "session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
+                "n_predict": random.randint(n_predict_min, n_predict), "seed": 13 * i + 2})
 
-        logger.info("Getting the prompt lengths...")
-        prompt_n = [get_prompt_length(d) for d in data]
+        if not synthetic_prompts:
+            logger.info("Getting the prompt lengths...")
+            prompt_n = [get_prompt_length(d) for d in data]
 
        logger.info("Starting the benchmark...\n")
        t0 = time()
-        results: list[tuple[int, list[float]]] = thread_map(send_prompt, data, max_workers=num_workers, chunksize=1)
+        results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
     finally:
        if server is not None:
            server["process"].terminate()
            server["process"].wait()
        if session is not None:
            session.close()
 
-    prompt_ms = []
+    prompt_t = []
     token_t = []
     depth_sum: int = 0
-    for pn, (pms, tat) in zip(prompt_n, results):
-        prompt_ms.append(pms)
+    for pn, (t_submit, tat) in zip(prompt_n, results):
+        prompt_t.append(tat[0] - t_submit)
        token_t += tat
        n_tokens: int = len(tat)
        depth_sum += n_tokens * pn
        depth_sum += n_tokens * (n_tokens + 1) // 2
+    assert len(token_t) > 0
     prompt_n = np.array(prompt_n, dtype=np.int64)
-    prompt_ms = np.array(prompt_ms, dtype=np.float64)
+    prompt_t = np.array(prompt_t, dtype=np.float64)
     token_t = np.array(token_t, dtype=np.float64)
 
     token_t -= t0
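With the final "data:" event no longer parsed for the server's timings, time to first token is now derived purely from client-side wall-clock timestamps: send_prompt records t_submit before issuing the request plus the arrival time of each streamed line, and the benchmark later computes tat[0] - t_submit. A rough sketch of that calculation with assumed timestamps (not output from a real run):

from time import time

t_submit = time()                                                          # taken just before the POST in send_prompt
token_arrival_times = [t_submit + 0.84, t_submit + 0.90, t_submit + 0.96]  # assumed stream arrivals
time_to_first_token = token_arrival_times[0] - t_submit                    # what ends up in prompt_t
print(f"time to first token: {1e3 * time_to_first_token:.0f} ms")          # ~840 ms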
@@ -167,18 +217,21 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
     logger.info(f"Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last / 60):.2f} requests/min")
     logger.info(f"Total prompt length: {np.sum(prompt_n)} tokens")
     logger.info(f"Average prompt length: {np.mean(prompt_n):.2f} tokens")
-    logger.info(f"Average prompt latency: {np.mean(prompt_ms):.2f} ms")
-    logger.info(f"Average prompt speed: {np.sum(prompt_n) / (1e-3 * np.sum(prompt_ms)):.2f} tokens/s")
+    logger.info(f"Average prompt latency: {1e3 * np.mean(prompt_t):.2f} ms")
+    logger.info(f"Average prompt speed: {np.sum(prompt_n) / np.sum(prompt_t):.2f} tokens/s")
     logger.info(f"Total generated tokens: {token_t.shape[0]}")
     logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
     logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
     logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
+    logger.info("")
+    logger.info(
+        "The above numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
+        "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
 
     plt.figure()
-    plt.scatter(prompt_n, prompt_ms, s=10.0, marker=".", alpha=0.25)
-    plt.xlim(0, 1.05 * np.max(prompt_n))
-    plt.ylim(0, 1.05 * np.max(prompt_ms))
-    plt.title(path_model)
+    plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
+    plt.xlim(0, 1.05e0 * np.max(prompt_n))
+    plt.ylim(0, 1.05e3 * np.max(prompt_t))
     plt.xlabel("Prompt length [tokens]")
     plt.ylabel("Time to first token [ms]")
     plt.savefig("prompt_time.png", dpi=240)
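Since prompt_t now holds seconds rather than the server-reported milliseconds, the log lines scale by 1e3 for the latency and divide tokens by seconds directly for the speed. A small numeric illustration with made-up values:

import numpy as np

prompt_n = np.array([1500, 1800], dtype=np.int64)    # prompt lengths in tokens (made up)
prompt_t = np.array([0.42, 0.55], dtype=np.float64)  # time to first token in seconds (made up)
print(f"Average prompt latency: {1e3 * np.mean(prompt_t):.2f} ms")                  # 485.00 ms
print(f"Average prompt speed: {np.sum(prompt_n) / np.sum(prompt_t):.2f} tokens/s")  # 3402.06 tokens/s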
@@ -187,7 +240,6 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
     plt.figure()
     plt.hist(token_t, np.arange(0, bin_max))
     plt.xlim(0, bin_max + 1)
-    plt.title(path_model)
     plt.xlabel("Time [s]")
     plt.ylabel("Num. tokens generated per second")
     plt.savefig("gen_rate.png", dpi=240)
@@ -196,15 +248,18 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
        description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
-        "Results are printed to console and visualized as plots (saved to current working directory).")
+        "Results are printed to console and visualized as plots (saved to current working directory). "
+        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
     parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
-    parser.add_argument("--path_model", type=str, required=True, help="Path to the model to use for the benchmark")
-    parser.add_argument("--path_log", type=str, default=None, help="Path to the model to use for the benchmark")
-    parser.add_argument("--port", type=int, default=18725, help="Port to use for the server during the benchmark")
-    parser.add_argument("--n_gpu_layers", type=int, default=999, help="Number of GPU layers for the server")
-    parser.add_argument("--parallel", type=int, default=16, help="Number of slots for the server")
-    parser.add_argument("--ctx_size", type=int, default=4096, help="Server context size per slot")
-    parser.add_argument("--n_prompts", type=int, default=1000, help="Number of prompts to evaluate")
+    parser.add_argument("--path_log", type=str, default="server-bench.log", help="Path to the log file for server output")
+    parser.add_argument(
+        "--prompt_source", type=str, default="rng-1024-2048",
+        help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
+        "rng-MIN-MAX for synthetic prompts with random lengths in the interval [MIN, MAX]")
+    parser.add_argument("--n_prompts", type=int, default=100, help="Number of prompts to evaluate")
     parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
+    parser.add_argument(
+        "--n_predict_min", type=int, default=1024,
+        help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)")
     args = parser.parse_args()
     benchmark(**vars(args))
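Because server options are now passed through LLAMA_ARG_* environment variables instead of dedicated CLI flags, a run is configured roughly as below. This is a sketch only: the model path and the LLAMA_ARG_MODEL variable name are assumptions (consult llama-server --help for the exact variable names), and it calls the benchmark() function defined in the script above directly.

import os

os.environ["LLAMA_ARG_MODEL"] = "/path/to/model.gguf"  # assumed variable name; see llama-server --help
os.environ["LLAMA_ARG_N_PARALLEL"] = "16"              # overrides the default of 32 set inside benchmark()

benchmark(
    path_server="llama-server",
    path_log="server-bench.log",
    prompt_source="rng-1024-2048",  # synthetic prompts of 1024-2048 tokens
    n_prompts=100,
    n_predict=2048,
    n_predict_min=1024,
)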