Commit c20dc54

fix CI
1 parent 8e8b5e4 commit c20dc54

File tree

requirements/requirements-server-bench.txt
scripts/server-bench.py

2 files changed: +27 -22 lines changed

requirements/requirements-server-bench.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-datasets~=3.6.0
+datasets~=3.2.0
 matplotlib~=3.10.0
 numpy~=1.26.4
 requests~=2.32.3
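
Side note on the version pin: ~=3.2.0 is a PEP 440 compatible-release specifier, equivalent to >=3.2.0,<3.3.0, so the requirement now admits 3.2.x patch releases only and excludes the 3.6.x series the old pin allowed. A quick way to verify the specifier's behavior, sketched with the third-party packaging library (the probe versions below are illustrative, not taken from this commit):

    from packaging.specifiers import SpecifierSet

    # The same specifier string as in requirements-server-bench.txt.
    spec = SpecifierSet("~=3.2.0")

    print("3.2.0" in spec)  # True: the lower bound itself
    print("3.2.7" in spec)  # True: patch releases within 3.2.x stay allowed
    print("3.3.0" in spec)  # False: the next minor release is excluded
    print("3.6.0" in spec)  # False: the series the old ~=3.6.0 pin allowed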

scripts/server-bench.py

Lines changed: 26 additions & 21 deletions
@@ -7,22 +7,27 @@
 from typing import Optional
 
 import datasets
+import logging
 import matplotlib.pyplot as plt
 import numpy as np
 import requests
 from tqdm.contrib.concurrent import thread_map
 
 
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger("server-bench")
+
+
 def get_prompts(n_prompts: int) -> list[str]:
-    print("Loading MMLU dataset...")
-    ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]
+    logger.info(" Loading MMLU dataset...")
+    ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
     if n_prompts >= 0:
         ret = ret[:n_prompts]
     return ret
 
 
 def get_server(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int) -> dict:
-    print("Starting the llama.cpp server...")
+    logger.info(" Starting the llama.cpp server...")
     address = f"http://localhost:{port}"
 
     popen_args: list[str] = [
@@ -78,7 +83,7 @@ def get_prompt_length(data: dict) -> int:
     return len(tokens)
 
 
-def send_prompt(data: dict) -> tuple[int, float, list[float]]:
+def send_prompt(data: dict) -> tuple[float, list[float]]:
     session = data["session"]
     server_address: str = data["server_address"]
 
@@ -93,6 +98,7 @@ def send_prompt(data: dict) -> tuple[int, float, list[float]]:
     json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
     response = session.post(f"{server_address}/completion", json=json_data, stream=True)
 
+    last_valid_line: str = ""
     token_arrival_times: list[float] = []
     for line in response.iter_lines(decode_unicode=True):
         if not line.startswith("data: "):
@@ -111,21 +117,20 @@ def send_prompt(data: dict) -> tuple[int, float, list[float]]:
 def benchmark(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int, n_prompts: int, n_predict: int):
     prompts: list[str] = get_prompts(n_prompts)
 
-    server = None
+    server: Optional[dict] = None
     try:
-        server: dict = get_server(path_server, path_model, path_log, port, n_gpu_layers, parallel, ctx_size)
+        server = get_server(path_server, path_model, path_log, port, n_gpu_layers, parallel, ctx_size)
         server_address: str = server["address"]
 
         with requests.Session() as session:
             data: list[dict] = []
             for i, p in enumerate(prompts):
                 data.append({"session": session, "server_address": server_address, "prompt": p, "n_predict": n_predict, "seed": i})
 
-            print("Getting the prompt lengths...")
-            prompt_n: list[int] = [get_prompt_length(d) for d in data]
+            logger.info(" Getting the prompt lengths...")
+            prompt_n = [get_prompt_length(d) for d in data]
 
-            print("Starting the benchmark...")
-            print()
+            logger.info(" Starting the benchmark...\n")
             t0 = time()
             results: list[tuple[int, list[float]]] = thread_map(send_prompt, data, max_workers=parallel + 1, chunksize=1)
     finally:
@@ -149,17 +154,17 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
     token_t -= t0
     token_t_last = np.max(token_t)
 
-    print()
-    print(f"Benchmark duration: {token_t_last:.2f} s")
-    print(f"Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min")
-    print(f"Total prompt length: {np.sum(prompt_n)} tokens")
-    print(f"Average prompt length: {np.mean(prompt_n):.2f} tokens")
-    print(f"Average prompt latency: {np.mean(prompt_ms):.2f} ms")
-    print(f"Average prompt speed: {np.sum(prompt_n) / (1e-3 * np.sum(prompt_ms)):.2f} tokens/s")
-    print(f"Total generated tokens: {token_t.shape[0]}")
-    print(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
-    print(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
-    print(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
+    logger.info("")
+    logger.info(f" Benchmark duration: {token_t_last:.2f} s")
+    logger.info(f" Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min")
+    logger.info(f" Total prompt length: {np.sum(prompt_n)} tokens")
+    logger.info(f" Average prompt length: {np.mean(prompt_n):.2f} tokens")
+    logger.info(f" Average prompt latency: {np.mean(prompt_ms):.2f} ms")
+    logger.info(f" Average prompt speed: {np.sum(prompt_n) / (1e-3 * np.sum(prompt_ms)):.2f} tokens/s")
+    logger.info(f" Total generated tokens: {token_t.shape[0]}")
+    logger.info(f" Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
+    logger.info(f" Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
+    logger.info(f" Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
 
     plt.figure()
     plt.scatter(prompt_n, prompt_ms, s=10.0, marker=".", alpha=0.25)
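
The bulk of the script change above swaps print() for the standard-library logging module. A minimal, self-contained sketch of that pattern outside the benchmark script (the function name and numbers here are illustrative, not part of the commit):

    import logging

    # Configure a root handler once and create a named logger, as the diff does.
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("server-bench")


    def report(duration_s: float, n_requests: int) -> None:
        # f-strings carry over directly from the old print() calls.
        logger.info(f"Benchmark duration: {duration_s:.2f} s")
        logger.info(f"Request throughput: {n_requests / duration_s:.2f} requests/s")


    if __name__ == "__main__":
        report(duration_s=12.5, n_requests=100)  # illustrative values only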

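The other recurring change is annotation cleanup of the kind a static type checker would flag, presumably the CI failure being fixed: server is now annotated once, as Optional[dict], before the try block instead of being re-annotated inside it, and a # type: ignore is added on the datasets.load_dataset line. A minimal sketch of the try/finally pattern, with hypothetical start_server()/stop_server() helpers standing in for the script's real ones:

    from typing import Optional


    def start_server() -> dict:
        # Hypothetical stand-in for get_server(): returns a handle for the process.
        return {"address": "http://localhost:8080"}


    def stop_server(server: dict) -> None:
        # Hypothetical stand-in for the shutdown logic in the real script.
        pass


    def run() -> None:
        # Annotate once, as Optional, before the try block; declaring the variable
        # as a plain dict inside the try while it starts out as None is the kind
        # of mismatch a type checker complains about.
        server: Optional[dict] = None
        try:
            server = start_server()
            print(server["address"])
        finally:
            if server is not None:
                stop_server(server)


    if __name__ == "__main__":
        run()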