Add prefix-cache dataset in mlc bench (#3065)
Add Loogle and React dataset. Add fake warmup option to avoid initializing prefix cache.
jinhongyii authored Dec 15, 2024
1 parent 385cef2 commit e5f3880
Showing 5 changed files with 417 additions and 23 deletions.
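For context: a prefix-cache benchmark sends many requests that share a long common prefix (Loogle is a long-context QA dataset), so a server-side prefix cache can reuse prefilled KV entries across requests instead of recomputing them. Below is a minimal sketch of the request shape such a dataset produces; every name in it is hypothetical and not part of this commit.

# Hypothetical sketch: prompts sharing one long document prefix, the access
# pattern a prefix cache is designed to exploit. Not code from this commit.
from typing import List

def make_shared_prefix_prompts(document: str, questions: List[str]) -> List[str]:
    """All prompts start with the same long document and differ only in the tail."""
    return [f"{document}\n\nQuestion: {q}\nAnswer:" for q in questions]

prompts = make_shared_prefix_prompts(
    "A very long article ... " * 200,
    ["What is the main topic?", "Who wrote it?"],
)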
15 changes: 11 additions & 4 deletions python/mlc_llm/bench/__main__.py
@@ -101,15 +101,17 @@ def run_pipeline(
         args.output_len_std,
     )
     request_records = pipeline(request_records)
-    assert len(request_records) == args.num_requests
-    sorted_requests: List[RequestRecord] = [None] * args.num_requests
+    assert len(request_records) == args.num_requests * args.num_gpus
+    sorted_requests: List[RequestRecord] = [None] * args.num_requests * args.num_gpus
     for request_record in request_records:
         assert request_record.request_id is not None
         assert sorted_requests[request_record.request_id] is None
         sorted_requests[request_record.request_id] = request_record
 
     request_records = MetricAnalyzer(tokenizer)(request_records)
-    report = generate_metrics_summary(request_records, args.num_requests, args.num_gpus)
+    report = generate_metrics_summary(
+        request_records, args.num_requests * args.num_gpus, args.num_gpus
+    )
     return report, sorted_requests
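As the updated assertions suggest, the pipeline now yields num_requests records per GPU, so the combined list holds num_requests * num_gpus entries and each request_id indexes directly into it. A toy illustration of that invariant, with stand-in data rather than real RequestRecord objects:

# Stand-in illustration of the id-indexed placement above (not real records).
num_requests, num_gpus = 4, 2
record_ids = list(range(num_requests * num_gpus))  # one record per request per GPU
slots = [None] * (num_requests * num_gpus)
for rid in record_ids:
    assert slots[rid] is None  # request ids must be unique
    slots[rid] = rid
assert None not in slots  # every slot is filled exactly once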


@@ -135,7 +137,7 @@ def _main():
     tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
     dataset = create_dataset(args, tokenizer)
     f_create_api_endpoint = functools.partial(create_api_endpoint, args)
-    pipelines = create_pipelines(args, f_create_api_endpoint)
+    pipelines = create_pipelines(args, f_create_api_endpoint, dataset)
     reports = []
     alltime_records = {}
     for i, pipeline in enumerate(pipelines):
@@ -291,6 +293,7 @@ def _main():
     parser.add_argument(
         "--timeout",
         type=float,
+        default=3 * 60 * 60,
         help="The timeout limit of each request.",
     )
     parser.add_argument(
@@ -380,4 +383,8 @@ def _main():
         "The --num-concurrent-requests should be provided when enabling this option.",
     )
 
+    parser.add_argument(
+        "--testset-name", type=str, help="The name of the testset. Only used for Loogle dataset"
+    )
+
     main(parser.parse_args())
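Taken together, the two parser changes above give --timeout a default of 3 * 60 * 60 = 10800 seconds (3 hours) and add an optional --testset-name flag. A self-contained sketch of the resulting behavior using standard argparse semantics; the example testset value is an assumption, not taken from this commit:

# Sketch of the parser behavior after this change (standard argparse semantics).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--timeout", type=float, default=3 * 60 * 60,
    help="The timeout limit of each request.",
)
parser.add_argument(
    "--testset-name", type=str,
    help="The name of the testset. Only used for Loogle dataset",
)

args = parser.parse_args([])
assert args.timeout == 10800          # default: 3 hours, in seconds
assert args.testset_name is None      # optional; dashes become underscores

args = parser.parse_args(["--testset-name", "example_testset"])  # hypothetical value
assert args.testset_name == "example_testset"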
6 changes: 4 additions & 2 deletions python/mlc_llm/bench/api_endpoint.py
@@ -57,7 +57,7 @@ def __init__(  # pylint: disable=too-many-arguments
     async def __aenter__(self) -> Self:
         import aiohttp  # pylint: disable=import-outside-toplevel,import-error
 
-        self.client = aiohttp.ClientSession()
+        self.client = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(self.timeout))
         return self
 
     async def __aexit__(self, exc_type, exc_value, tb) -> None:
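The session-level timeout above uses aiohttp's ClientTimeout, whose first positional field is total: an overall cap, in seconds, covering connection setup and the entire request/response. A minimal standalone sketch of the same pattern; the URL is a placeholder:

# Minimal sketch of a session-wide total timeout in aiohttp (placeholder URL).
import asyncio
import aiohttp

async def fetch() -> None:
    timeout = aiohttp.ClientTimeout(10800)  # positional argument is `total`
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.get("https://example.com") as response:
            print(response.status)

asyncio.run(fetch())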
@@ -249,7 +249,9 @@ async def __call__(  # pylint: disable=too-many-branches,too-many-statements
         start_time = time.monotonic()
 
         try:
-            async with self.client.post(self.url, json=payload, headers=self.headers) as response:
+            async with self.client.post(
+                self.url, json=payload, headers=self.headers, timeout=3600
+            ) as response:
                 assert response.status == 200, await response.text()
                 if payload["stream"]:
                     async for chunk in response.content:
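The per-request timeout=3600 here overrides the session-wide default for this POST alone; aiohttp treats a bare number as ClientTimeout(total=3600). A standalone sketch of the same pattern, including a line-by-line read of the streamed body; the URL and payload are placeholders:

# Sketch: per-request timeout overriding the session default, plus the
# streaming read loop used above (placeholder URL and payload).
import asyncio
import aiohttp

async def stream_post() -> None:
    session_timeout = aiohttp.ClientTimeout(total=3 * 60 * 60)
    async with aiohttp.ClientSession(timeout=session_timeout) as session:
        async with session.post(
            "http://localhost:8000/v1/chat/completions",
            json={"messages": [{"role": "user", "content": "hi"}], "stream": True},
            timeout=3600,  # a bare number becomes ClientTimeout(total=3600)
        ) as response:
            assert response.status == 200, await response.text()
            async for chunk in response.content:  # yields the body line by line
                print(chunk.decode("utf-8", errors="ignore"), end="")

asyncio.run(stream_post())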