[Bench] Add AzureLLMInference dataset (#3104)
This PR introduces the [AzureLLMInference dataset](https://github.com/Azure/AzurePublicDataset).
This dataset contains a timestamp for each entry, and this PR
also introduces a dataset replay mode for the mlc-llm benchmark.
This mode replays the requests according to their recorded timestamps.
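For context, here is a minimal sketch (not part of this commit) of how the new pieces compose: the trace CSV is parsed into timestamped request records, and `ScaleTimestamp` rescales those timestamps before replay. The tokenizer name and trace path below are placeholders.

```python
from transformers import AutoTokenizer

from mlc_llm.bench.dataset import AzureLLMInferenceDataset
from mlc_llm.bench.request_processor import ScaleTimestamp

# Placeholder tokenizer and trace path; substitute your own.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Parse the Azure trace (columns TIMESTAMP, ContextTokens, GeneratedTokens)
# into request records that carry replay timestamps.
dataset = AzureLLMInferenceDataset("azure_llm_inference_trace.csv", tokenizer)
records = dataset.generate_request_records(input_len=None, output_len=None)

# Halve every timestamp, i.e. replay the trace at twice the recorded speed.
records = ScaleTimestamp(0.5)(records)
```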
MasterJH5574 authored Jan 22, 2025
1 parent 03509ce commit 2c1001b
Showing 3 changed files with 222 additions and 18 deletions.
8 changes: 8 additions & 0 deletions python/mlc_llm/bench/__main__.py
@@ -247,6 +247,14 @@ def _main():
"When specified, the benchmark sends these many new requests each second. "
'If it is "inf", all requests will be sent together at once.',
)
parser.add_argument(
"--replay-timestamp-scale",
type=float,
help="The timestamp scale when replaying the timestamps in a dataset. "
'The dataset replay mode is enabled when neither "--num-concurrent-requests" nor '
'"--request-rate" is specified. '
"The scale is 1 by default in the replay mode.",
)
parser.add_argument(
"--input-len",
type=int,
128 changes: 126 additions & 2 deletions python/mlc_llm/bench/dataset.py
@@ -3,9 +3,11 @@
import argparse
import json
import random
from datetime import datetime
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd # pylint: disable=import-error
from datasets import load_dataset # pylint: disable=import-error
from transformers import AutoTokenizer # pylint: disable=import-error

@@ -25,6 +27,10 @@ class Dataset: # pylint: disable=too-few-public-methods
# For some datasets (e.g., datasets that share a common prefix),
# we need fake warmup requests to avoid prefilling common prefixes to the engine.
require_fake_warmup: bool = False
# Whether the dataset contains timestamps already.
# If the dataset comes with timestamps, the benchmark can just replay
# the requests according to their timestamps.
timestamp_available: bool = False

def generate_request_records(
self,
@@ -702,19 +708,111 @@ def generate_request_records( # pylint: disable=too-many-locals
return request_records


class AzureLLMInferenceDataset(Dataset): # pylint: disable=too-few-public-methods
"""The dataset class for AzureLLMInference dataset.
Reference: https://github.com/Azure/AzurePublicDataset
"""

timestamp_available: bool = True

def __init__(self, dataset_path: str, tokenizer: AutoTokenizer) -> None:
df = pd.read_csv(dataset_path)
self.tokenizer = tokenizer

# Keep only the entries whose context and generated token counts are both at least 4.
self.dataset = [
(
entry["TIMESTAMP"],
min(entry["ContextTokens"], tokenizer.model_max_length, self.truncate_length),
min(entry["GeneratedTokens"], tokenizer.model_max_length, self.truncate_length),
)
for _, entry in df.iterrows()
if entry["ContextTokens"] >= 4 and entry["GeneratedTokens"] >= 4
]

def generate_request_records( # pylint: disable=too-many-locals
self,
input_len: Optional[int],
output_len: Optional[int],
input_len_std: float = 0.0,
output_len_std: float = 0.0,
) -> List[RequestRecord]:
time_fmt = "%Y-%m-%d %H:%M:%S.%f"
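# The trailing character of each trace timestamp (presumably a "Z" suffix)
# is stripped before parsing.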
start_time = datetime.strptime(self.dataset[0][0][:-1], time_fmt)
request_records = []
for timestamp, input_length, output_length in self.dataset:
# Discard the request if its recorded input length cannot cover the
# target input length plus four standard deviations of headroom.
if input_len is not None and input_length < input_len + 4 * input_len_std:
continue

if input_len is not None:
input_length = round(
float(np.random.normal(loc=input_len, scale=input_len_std, size=1)[0])
)
if output_len is not None:
output_length = round(
float(np.random.normal(loc=output_len, scale=output_len_std, size=1)[0])
)
elif output_length <= 1:
continue

prompt_token_ids = [
random.randint(0, self.tokenizer.vocab_size - 1) for _ in range(input_length)
]
while True:
# Adjust the token ids until the retokenization on the decoded string
# matches the required input length.
prompt = self.tokenizer.decode(prompt_token_ids)
retokenized_token_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
if len(retokenized_token_ids) < input_length:
prompt_token_ids = retokenized_token_ids + [
random.randint(0, self.tokenizer.vocab_size - 1)
for _ in range(input_length - len(retokenized_token_ids))
]
elif len(retokenized_token_ids) > input_length:
prompt_token_ids = retokenized_token_ids[:input_length]
else:
break

time_diff = (datetime.strptime(timestamp[:-1], time_fmt) - start_time).total_seconds()
request_records.append(
RequestRecord(
chat_cmpl=ChatCompletionRequest(
messages=[{"role": "user", "content": prompt}],
model="",
max_tokens=output_length,
),
timestamp=time_diff,
metrics=Metrics(
success=False,
start_time=0,
finish_time=0,
end_to_end_latency_s=0,
input_tokens=input_length,
),
)
)
return request_records


SUPPORTED_DATASET = [
"sharegpt",
"llmperf",
"json-mode-eval",
"loogle",
"react",
"wildchat",
"azure-llm-inference",
]


def create_dataset( # pylint: disable=too-many-return-statements,too-many-branches
args: argparse.Namespace, tokenizer: AutoTokenizer
) -> Dataset:
"""Create a dataset instance with regard to the specified dataset kind and file path."""
if args.dataset_path is not None and not isinstance(args.dataset_path, str):
raise TypeError(f"Invalid dataset path {args.dataset_path}. Please use a string.")
if args.dataset is None and args.dataset_path is not None:
# Auto-detect the dataset kind by looking into the dataset path.
if "sharegpt" in args.dataset_path.lower():
args.dataset = "sharegpt"
@@ -724,8 +822,16 @@ def create_dataset(args: argparse.Namespace, tokenizer: AutoTokenizer) -> "Datas
'Please specify the dataset kind via "--dataset".'
)
if args.dataset == "sharegpt":
if args.dataset_path is None:
raise ValueError(
'ShareGPT dataset requires dataset path. Please specify it with "--dataset-path".'
)
return ShareGPTDataset(args.dataset_path, tokenizer, args.apply_chat_template)
if args.dataset == "llmperf":
if args.dataset_path is None:
raise ValueError(
'LLMPerf dataset requires dataset path. Please specify it with "--dataset-path".'
)
assert (
args.apply_chat_template is False
), "LLMPerf dataset does not support applying chat template"
@@ -738,15 +844,33 @@ def create_dataset(args: argparse.Namespace, tokenizer: AutoTokenizer) -> "Datas
), "JSON mode evaluation does not support applying chat template"
return JSONModeEvalDataset(tokenizer)
if args.dataset == "loogle":
if args.dataset_path is None:
raise ValueError(
'Loogle dataset requires a testset name. Please specify it with "--dataset-path".'
)
assert (
args.apply_chat_template is False
), "Loogle dataset does not support applying chat template"
return LoogleDataset(tokenizer, testset_name=args.dataset_path)
if args.dataset == "react":
if args.dataset_path is None:
raise ValueError(
'ReAct dataset requires dataset path. Please specify it with "--dataset-path".'
)
assert (
args.apply_chat_template is False
), "ReAct dataset does not support applying chat template"
return ReActDataset(args.dataset_path, tokenizer)
if args.dataset == "wildchat":
return WildChatDataset(tokenizer, args.apply_chat_template)
if args.dataset == "azure-llm-inference":
if args.dataset_path is None:
raise ValueError(
"AzureLLMInference dataset requires dataset path. "
'Please specify it with "--dataset-path".'
)
assert (
args.apply_chat_template is False
), "AzureLLMInference dataset does not support applying chat template"
return AzureLLMInferenceDataset(args.dataset_path, tokenizer)
raise ValueError(f"Unrecognized dataset {args.dataset}")
104 changes: 88 additions & 16 deletions python/mlc_llm/bench/request_processor.py
@@ -51,8 +51,11 @@ def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
class SampleRequests(RequestProcessor): # pylint: disable=too-few-public-methods
"""The processor that samples requests out from the given request list."""

def __init__(self, num_requests: int, take_first_x_requests: bool = False) -> None:
self.num_requests = num_requests
# If `take_first_x_requests` is True, the first `num_requests` requests
# are returned and sampling will not happen.
self.take_first_x_requests = take_first_x_requests

def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
assert len(request_records) > 0, "Empty input request record."
@@ -69,12 +72,20 @@ def _sample_from_plain_request_records(
self, request_records: List[RequestRecord]
) -> List[RequestRecord]:
samples: List[RequestRecord] = []
if self.take_first_x_requests:
if len(request_records) < self.num_requests:
raise ValueError(
f"Insufficient requests. Requiring {self.num_requests} requests "
f"but only {len(request_records)} are available."
)
samples = copy.deepcopy(list(request_records[: self.num_requests]))
else:
while len(samples) < self.num_requests:
# Create a new list so that the in-place shuffle does not mutate the input list.
records = list(request_records)
random.shuffle(records)
samples += copy.deepcopy(records)
samples = samples[: self.num_requests]
for i, record in enumerate(samples):
record.request_id = i
return samples
@@ -95,7 +106,8 @@ def _sample_from_grouped_request_records(

# Create a new list so that the in-place shuffle does not mutate the input list.
records = list(grouped_request_records)
if not self.take_first_x_requests:
random.shuffle(records)
remaining = self.num_requests
samples: List[RequestRecord] = []
for grouped_request_record in grouped_request_records:
@@ -183,6 +195,22 @@ def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
return request_records


class ScaleTimestamp(RequestProcessor): # pylint: disable=too-few-public-methods
"""Scale the timestamp of requests by the given scale factor."""

def __init__(self, timestamp_scale: float):
self.timestamp_scale = timestamp_scale

def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
for request_record in request_records:
if request_record.timestamp is None:
raise ValueError(
f"The timestamp of request {request_record} has not been initialized."
)
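# A scale below 1 compresses the timestamps (faster replay);
# a scale above 1 stretches them out.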
request_record.timestamp *= self.timestamp_scale
return request_records


class MetricAnalyzer(RequestProcessor): # pylint: disable=too-few-public-methods
"""The processor that analyzes the raw benchmark results and computes more detailed metrics."""

@@ -463,7 +491,6 @@ def __init__( # pylint: disable=too-many-arguments
disable_tqdm: bool,
max_schedule_gap: float,
num_requests: int,
) -> None:
if num_processes is None:
# We assign each process at most 32 requests to send
@@ -472,7 +499,6 @@
super().__init__(f_create_api_endpoint, num_processes, disable_tqdm)
self.max_schedule_gap = max_schedule_gap
self.num_requests = num_requests

def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
assert len(request_records) > 0
@@ -574,7 +600,7 @@ async def _task(request_record: RequestRecord) -> None:
)


def create_pipelines( # pylint: disable=too-many-branches
args: argparse.Namespace, f_create_api_endpoint: Callable[[], APIEndPoint], dataset: Dataset
) -> List[RequestProcessor]:
"""Creating request processing pipelines with regard to the specified args."""
@@ -586,6 +612,10 @@ def create_pipelines(
'Both "num_concurrent_requests" and "request_rate" are specified. '
"Please specify only one of them."
)
if args.replay_timestamp_scale is not None:
raise ValueError(
"Dataset replay is unsupported when fixing number of concurrent requests."
)
for num_concurrent_requests in args.num_concurrent_requests:
num_warmup_requests = (
args.num_warmup_requests
@@ -622,6 +652,8 @@
"Please specify the number of warmup requests via "
'"--num-warmup-requests" when fixing request rate.'
)
if args.replay_timestamp_scale is not None:
raise ValueError("Dataset replay is unsupported when fixing request rates.")
num_total_requests = int(
args.num_requests if not args.per_gpu_workload else args.num_requests * args.num_gpus
)
@@ -649,15 +681,55 @@
args.disable_tqdm,
args.max_schedule_gap,
args.num_requests,
),
cuda_profile_url=cuda_profile_url,
fake_warmup=dataset.require_fake_warmup,
),
)
for request_rate in args.request_rate
]

# Default: dataset replay mode
# The dataset must come with timestamps.
if not dataset.timestamp_available:
raise ValueError(
"The dataset does not have timestamps, so dataset replay is unsupported. "
'Please specify one of "num_concurrent_requests" '
'and "request_rate".'
)
if args.per_gpu_workload:
raise ValueError("Fixing per-GPU workload is not compatible with dataset replay.")
if args.num_warmup_requests is None:
raise ValueError(
"Please specify the number of warmup requests via "
'"--num-warmup-requests" for dataset replay.'
)
timestamp_scale = args.replay_timestamp_scale or 1.0
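# Datasets that require fake warmup synthesize their own warmup requests,
# so no extra trace entries are reserved for warmup here.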
if dataset.require_fake_warmup:
num_samples = args.num_requests
else:
num_samples = args.num_requests + args.num_warmup_requests
return [
SequentialProcessor(
LogMessage(f"Dataset replay with time scaling of {timestamp_scale}"),
SampleRequests(num_samples, take_first_x_requests=True),
AttachModelName(args.tokenizer),
ScaleTimestamp(timestamp_scale),
AttachStreamFlag(args.stream),
AttachSamplingOptions(args.temperature, args.top_p, args.ignore_eos),
AttachExecutionFeature({"timestamp_scale": timestamp_scale}),
WarmupAndRun(
num_warmup_requests=args.num_warmup_requests,
num_benchmark_requests=args.num_requests,
pipeline=FixTimestampExecutor(
f_create_api_endpoint,
args.num_process_workers,
args.disable_tqdm,
args.max_schedule_gap,
args.num_requests,
),
cuda_profile_url=cuda_profile_url,
fake_warmup=dataset.require_fake_warmup,
),
)
]
