Commit 5cae766

scripts: synthetic prompt mode for server-bench.py (ggml-org#14695)
1 parent: 4b91d6f

File tree (2 files changed: +124 / -69 lines)
  scripts/server-bench.py
  tools/server/README.md

scripts/server-bench.py

File mode changed: 100644 → 100755
Lines changed: 123 additions & 68 deletions
@@ -2,9 +2,11 @@
 
 import argparse
 import json
+import os
+import random
 import subprocess
 from time import sleep, time
-from typing import Optional
+from typing import Optional, Union
 
 import datasets
 import logging
@@ -18,46 +20,54 @@
 logger = logging.getLogger("server-bench")
 
 
-def get_prompts(n_prompts: int) -> list[str]:
-    logger.info("Loading MMLU dataset...")
-    ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
+def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
+    ret = []
+    if dataset_name.lower() == "mmlu":
+        logger.info("Loading MMLU dataset...")
+        ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
+    else:
+        return None
     if n_prompts >= 0:
         ret = ret[:n_prompts]
     return ret
 
 
-def get_server(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int) -> dict:
+def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int) -> list[int]:
+    assert n_prompts >= 0
+    ret: list[int] = []
+    for i in range(n_prompts):
+        random.seed(13 * i + 0)
+        ret.append(random.randint(prompt_length_min, prompt_length_max))
+    return ret
+
+
+def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
+    return [[random.randint(100, 10000) for _ in range(pl)] for pl in prompt_lengths]
+
+
+def get_server(path_server: str, path_log: Optional[str]) -> dict:
     logger.info("Starting the llama.cpp server...")
-    address = f"http://localhost:{port}"
-
-    popen_args: list[str] = [
-        path_server,
-        "--flash-attn",
-        "--n-gpu-layers", str(n_gpu_layers),
-        "--parallel", str(parallel),
-        "--ctx-size", str(parallel * ctx_size),
-        "--model", path_model,
-        "--port", str(port),
-        "--swa-full",  # FIXME performance bad otherwise
-        # "--attn-streams",
-    ]
-    fout = open("bench.log", "w") if path_log is not None else subprocess.DEVNULL
-    process = subprocess.Popen(popen_args, stdout=fout, stderr=subprocess.STDOUT)
+    hostname: str = os.environ.get("LLAMA_ARG_HOST", "127.0.0.1")
+    port: str = os.environ.get("LLAMA_ARG_PORT", "8080")
+    address: str = f"http://{hostname}:{port}"
+
+    fout = open(path_log, "w") if path_log is not None else subprocess.DEVNULL
+    process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)
 
     n_failures: int = 0
     while True:
         try:
             sleep(1.0)
             exit_code = process.poll()
             if exit_code is not None:
-                raise RuntimeError(f"llama.cpp server for {path_model} exited unexpectedly with exit code {exit_code}")
+                raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}, see {path_log}")
             response = requests.get(f"{address}/health")
             if response.status_code == 200:
                 break
         except requests.ConnectionError:
             n_failures += 1
             if n_failures >= 10:
-                raise RuntimeError(f"llama.cpp server for {path_model} is not healthy after 10 seconds")
+                raise RuntimeError("llama.cpp server is not healthy after 10 seconds")
 
     return {"process": process, "address": address, "fout": fout}
 
@@ -87,76 +97,116 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
     session = data["session"]
     server_address: str = data["server_address"]
 
-    response = session.post(
-        f"{server_address}/apply-template",
-        json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
-    )
-    if response.status_code != 200:
-        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
-    prompt: str = json.loads(response.text)["prompt"]
-
-    json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
-    response = session.post(f"{server_address}/completion", json=json_data, stream=True)
+    t_submit = time()
+    if data["synthetic_prompt"]:
+        json_data: dict = {
+            "prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
+            "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
+        response = session.post(f"{server_address}/completion", json=json_data, stream=True)
+    else:
+        response = session.post(
+            f"{server_address}/apply-template",
+            json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
+        )
+        if response.status_code != 200:
+            raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+        prompt: str = json.loads(response.text)["prompt"]
+
+        json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
+        response = session.post(f"{server_address}/completion", json=json_data, stream=True)
 
-    last_valid_line: str = ""
     token_arrival_times: list[float] = []
-    for line in response.iter_lines(decode_unicode=True):
-        if not line.startswith("data: "):
+    for line in response.iter_lines(decode_unicode=False):
+        if not line.startswith(b"data: "):
             continue
-        last_valid_line = line
         token_arrival_times.append(time())
     token_arrival_times = token_arrival_times[:-1]
 
     if response.status_code != 200:
         raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
-    timings: dict = json.loads(last_valid_line[6:])["timings"]
 
-    return (timings["prompt_ms"], token_arrival_times)
-
-
-def benchmark(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int, n_prompts: int, n_predict: int):
-    num_workers: int = parallel + 1
-    prompts: list[str] = get_prompts(n_prompts)
+    return (t_submit, token_arrival_times)
+
+
+def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int):
+    if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
+        logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
+        os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
+    if os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
+        logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
+        os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
+    if os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
+        logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
+        os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"
+
+    parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL", 1))
+    prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts)
+    synthetic_prompts: bool = prompts is None
+    prompt_n = []
+
+    if synthetic_prompts:
+        prompt_source_split: list[str] = prompt_source.split("-")
+        assert len(prompt_source_split) == 3
+        assert prompt_source_split[0].lower() == "rng"
+        prompt_length_min: int = int(prompt_source_split[1])
+        prompt_length_max: int = int(prompt_source_split[2])
+        logger.info("Generating random prompts...")
+        prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max)
+        prompts = get_prompts_rng(prompt_n)
+    else:
+        n_predict_min = n_predict
+
+    if os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
+        context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
+        context_total: int = context_per_slot * parallel
+        os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
+        logger.info(f"LLAMA_ARG_CTX_SIZE not explicitly set, using {context_total} ({context_per_slot} per slot).")
 
     server: Optional[dict] = None
     session = None
     try:
-        server = get_server(path_server, path_model, path_log, port, n_gpu_layers, parallel, ctx_size)
+        server = get_server(path_server, path_log)
         server_address: str = server["address"]
 
-        adapter = requests.adapters.HTTPAdapter(pool_connections=num_workers, pool_maxsize=num_workers)  # type: ignore
+        adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
         session = requests.Session()
         session.mount("http://", adapter)
        session.mount("https://", adapter)
 
         data: list[dict] = []
+
         for i, p in enumerate(prompts):
-            data.append({"session": session, "server_address": server_address, "prompt": p, "n_predict": n_predict, "seed": i})
+            random.seed(13 * i + 1)
+            data.append({
+                "session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
+                "n_predict": random.randint(n_predict_min, n_predict), "seed": 13 * i + 2})
 
-        logger.info("Getting the prompt lengths...")
-        prompt_n = [get_prompt_length(d) for d in data]
+        if not synthetic_prompts:
+            logger.info("Getting the prompt lengths...")
+            prompt_n = [get_prompt_length(d) for d in data]
 
         logger.info("Starting the benchmark...\n")
         t0 = time()
-        results: list[tuple[int, list[float]]] = thread_map(send_prompt, data, max_workers=num_workers, chunksize=1)
+        results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
     finally:
         if server is not None:
             server["process"].terminate()
             server["process"].wait()
         if session is not None:
             session.close()
 
-    prompt_ms = []
+    prompt_t = []
     token_t = []
     depth_sum: int = 0
-    for pn, (pms, tat) in zip(prompt_n, results):
-        prompt_ms.append(pms)
+    for pn, (t_submit, tat) in zip(prompt_n, results):
+        prompt_t.append(tat[0] - t_submit)
         token_t += tat
         n_tokens: int = len(tat)
         depth_sum += n_tokens * pn
         depth_sum += n_tokens * (n_tokens + 1) // 2
+    assert len(token_t) > 0
     prompt_n = np.array(prompt_n, dtype=np.int64)
-    prompt_ms = np.array(prompt_ms, dtype=np.float64)
+    prompt_t = np.array(prompt_t, dtype=np.float64)
     token_t = np.array(token_t, dtype=np.float64)
 
     token_t -= t0
@@ -167,18 +217,21 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
     logger.info(f"Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min")
     logger.info(f"Total prompt length: {np.sum(prompt_n)} tokens")
     logger.info(f"Average prompt length: {np.mean(prompt_n):.2f} tokens")
-    logger.info(f"Average prompt latency: {np.mean(prompt_ms):.2f} ms")
-    logger.info(f"Average prompt speed: {np.sum(prompt_n) / (1e-3 * np.sum(prompt_ms)):.2f} tokens/s")
+    logger.info(f"Average prompt latency: {1e3 * np.mean(prompt_t):.2f} ms")
+    logger.info(f"Average prompt speed: {np.sum(prompt_n) / np.sum(prompt_t):.2f} tokens/s")
     logger.info(f"Total generated tokens: {token_t.shape[0]}")
     logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
     logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
     logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
+    logger.info("")
+    logger.info(
+        "The above numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
+        "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
 
     plt.figure()
-    plt.scatter(prompt_n, prompt_ms, s=10.0, marker=".", alpha=0.25)
-    plt.xlim(0, 1.05 * np.max(prompt_n))
-    plt.ylim(0, 1.05 * np.max(prompt_ms))
-    plt.title(path_model)
+    plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
+    plt.xlim(0, 1.05e0 * np.max(prompt_n))
+    plt.ylim(0, 1.05e3 * np.max(prompt_t))
     plt.xlabel("Prompt length [tokens]")
     plt.ylabel("Time to first token [ms]")
     plt.savefig("prompt_time.png", dpi=240)
@@ -187,7 +240,6 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
     plt.figure()
     plt.hist(token_t, np.arange(0, bin_max))
     plt.xlim(0, bin_max + 1)
-    plt.title(path_model)
     plt.xlabel("Time [s]")
     plt.ylabel("Num. tokens generated per second")
     plt.savefig("gen_rate.png", dpi=240)
@@ -196,15 +248,18 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
-        "Results are printed to console and visualized as plots (saved to current working directory).")
+        "Results are printed to console and visualized as plots (saved to current working directory). "
+        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
     parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
-    parser.add_argument("--path_model", type=str, required=True, help="Path to the model to use for the benchmark")
-    parser.add_argument("--path_log", type=str, default=None, help="Path to the model to use for the benchmark")
-    parser.add_argument("--port", type=int, default=18725, help="Port to use for the server during the benchmark")
-    parser.add_argument("--n_gpu_layers", type=int, default=999, help="Number of GPU layers for the server")
-    parser.add_argument("--parallel", type=int, default=16, help="Number of slots for the server")
-    parser.add_argument("--ctx_size", type=int, default=4096, help="Server context size per slot")
-    parser.add_argument("--n_prompts", type=int, default=1000, help="Number of prompts to evaluate")
+    parser.add_argument("--path_log", type=str, default="server-bench.log", help="Path to the model to use for the benchmark")
+    parser.add_argument(
+        "--prompt_source", type=str, default="rng-1024-2048",
+        help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
+        "rng-MIN-MAX for synthetic prompts with random lengths in the interval [MIN, MAX]")
+    parser.add_argument("--n_prompts", type=int, default=100, help="Number of prompts to evaluate")
     parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
+    parser.add_argument(
+        "--n_predict_min", type=int, default=1024,
+        help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)")
     args = parser.parse_args()
     benchmark(**vars(args))
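For reference, the deterministic request set behind the new rng-MIN-MAX prompt source can be reproduced standalone. The sketch below restates get_prompt_lengths_rng, get_prompts_rng, and the per-request seeding from benchmark() in the diff above; the helper names and the final print are illustrative only, and the values assume the new defaults (--prompt_source rng-1024-2048, --n_prompts 100, --n_predict 2048, --n_predict_min 1024).

import random

# Prompt lengths: one deterministic seed per prompt index (13 * i + 0), as in get_prompt_lengths_rng().
def prompt_lengths_rng(n_prompts: int, length_min: int, length_max: int) -> list[int]:
    lengths: list[int] = []
    for i in range(n_prompts):
        random.seed(13 * i + 0)
        lengths.append(random.randint(length_min, length_max))
    return lengths

# Prompts are raw token IDs in [100, 10000], sent directly to /completion with
# ignore_eos=True and cache_prompt=False, as in get_prompts_rng()/send_prompt().
def prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
    return [[random.randint(100, 10000) for _ in range(pl)] for pl in prompt_lengths]

# Per-request generation length and sampling seed, as in the benchmark() loop.
def request_params(i: int, n_predict_min: int, n_predict_max: int) -> dict:
    random.seed(13 * i + 1)
    return {"n_predict": random.randint(n_predict_min, n_predict_max), "seed": 13 * i + 2}

lengths = prompt_lengths_rng(100, 1024, 2048)
prompts = prompts_rng(lengths)
params = [request_params(i, 1024, 2048) for i in range(len(prompts))]
print(len(prompts), lengths[0], params[0])

With these defaults and LLAMA_ARG_CTX_SIZE unset, the script reserves int(1.05 * (n_predict + max(prompt_n))) context tokens per slot, i.e. at most int(1.05 * (2048 + 2048)) = 4300 tokens per slot, multiplied by LLAMA_ARG_N_PARALLEL for the total context size.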

tools/server/README.md

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp.
 **Features:**
 * LLM inference of F16 and quantized models on GPU and CPU
 * [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes
-* Reranking endoint (https://github.com/ggml-org/llama.cpp/pull/9510)
+* Reranking endpoint (https://github.com/ggml-org/llama.cpp/pull/9510)
 * Parallel decoding with multi-user support
 * Continuous batching
 * Multimodal ([documentation](../../docs/multimodal.md)) / with OpenAI-compatible API support
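Because the reworked script forwards server settings through LLAMA_ARG_* environment variables instead of its own CLI flags, a synthetic-prompt run could be driven as in the sketch below. The script flags match the argparse definitions in the diff; LLAMA_ARG_MODEL and the concrete paths/values are illustrative assumptions (see llama-server --help for the authoritative variable names).

import os
import subprocess

# Hedged example: launch the benchmark with synthetic prompts of 1024-2048 tokens.
env = dict(os.environ)
env["LLAMA_ARG_MODEL"] = "/models/my-model.gguf"  # assumed llama-server env var; placeholder path
env["LLAMA_ARG_N_PARALLEL"] = "32"                # the script defaults this to 32 if unset
env["LLAMA_ARG_N_GPU_LAYERS"] = "999"             # defaulted to 999 if unset
# LLAMA_ARG_FLASH_ATTN defaults to "true" and LLAMA_ARG_CTX_SIZE is derived from the
# prompt/generation lengths if left unset, so they are omitted here.

subprocess.run(
    [
        "python3", "scripts/server-bench.py",
        "--path_server", "llama-server",
        "--path_log", "server-bench.log",
        "--prompt_source", "rng-1024-2048",
        "--n_prompts", "100",
        "--n_predict", "2048",
        "--n_predict_min", "1024",
    ],
    env=env,
    check=True,
)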
