
Commit 03509ce

[Bench] Add WildChat dataset (#3102)
This PR introduces [WildChat dataset](https://huggingface.co/datasets/allenai/WildChat) to mlc_llm bench.
1 parent c84e80a commit 03509ce
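For reference, the new dataset class in the diff below pulls conversations directly from the Hugging Face Hub and keeps only the first user/assistant exchange of each conversation. A minimal standalone sketch of that preprocessing step, mirroring the filtering rule from the patch (it only needs the `datasets` package; the WildChat repo may require accepting its terms on the Hub first):

```python
from datasets import load_dataset

# Same source and filtering rule as WildChatDataset.__init__ in the diff below:
# keep conversations whose first two turns are a user prompt followed by an assistant reply.
raw = load_dataset("allenai/WildChat", split="train")
pairs = [
    (entry["conversation"][0]["content"], entry["conversation"][1]["content"])
    for entry in raw
    if len(entry["conversation"]) >= 2
    and entry["conversation"][0]["role"] == "user"
    and entry["conversation"][1]["role"] == "assistant"
]
print(f"{len(pairs)} prompt/completion pairs")
```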

2 files changed: +123 −186 lines changed

python/mlc_llm/bench/__main__.py

Lines changed: 0 additions & 1 deletion
@@ -187,7 +187,6 @@ def _main():
     parser.add_argument(
         "--dataset-path",
         type=str,
-        required=True,
         help="The dataset file path.",
     )
     parser.add_argument(
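With `required=True` removed, a run that omits `--dataset-path` now parses cleanly and `args.dataset_path` is simply `None`, which hub-backed datasets such as WildChat can ignore. A tiny standalone argparse sketch of that behavior (not the actual `_main()` code):

```python
import argparse

parser = argparse.ArgumentParser()
# Mirrors the flag after this change: optional, so omitting it yields None rather than an error.
parser.add_argument("--dataset-path", type=str, help="The dataset file path.")

args = parser.parse_args([])  # no --dataset-path given
print(args.dataset_path)  # None
```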

python/mlc_llm/bench/dataset.py

Lines changed: 123 additions & 185 deletions
@@ -581,198 +581,134 @@ def generate_request_records(
         return request_records
 
 
-# Todo: dataset of log replay  # pylint: disable=fixme
-# NOTE: moved from the previous "python/mlc_llm/bench/prompts.py"
-# class PromptsGenerator:  # pylint: disable=too-few-public-methods
-#     """
-#     Generates prompts of a specified token length from a text file containing potential prompts.
-#     """
-
-#     def __init__(
-#         self,
-#         prompts_path: Optional[str] = None,
-#         json_prompts_path: Optional[str] = None,
-#         tokenizer: Optional[Any] = None,
-#         seed: Optional[int] = 11111,
-#     ) -> None:
-#         """
-#         Initializes the PromptsGenerator with the file path and tokenizer.
-
-#         Parameters
-#         ----------
-#         prompts_path : Optional[str]
-#             The path to the file containing the source prompts. This file can be
-#             a plain text file, with each line representing a separate prompt str,
-#             or a .jsonl file where each line is a JSON object formatted as
-#             {"prompt": "prompt text", "input_tokens": 10}.
-
-#         json_prompts_path : Optional[str]
-#             The path to the file containing the source json prompts. This file a
-#             .jsonl file where each line is a JSON object formatted as
-#             {"messages": List[Dict[str, Any]], "response_format": Dict[str, Any]}.
-
-#         tokenizer : Optional[Any]
-#             The tokenizer object to use for tokenizing the prompts.
-
-#         seed : Optional[int]
-#             The seed for the random number generator.
-#         """
-#         random.seed(seed)
-#         self.tokenizer = tokenizer
-#         if not self.tokenizer:
-#             from transformers import (  # pylint: disable=import-outside-toplevel,import-error
-#                 LlamaTokenizerFast,
-#             )
-
-#             self.tokenizer = LlamaTokenizerFast.from_pretrained(
-#                 "hf-internal-testing/llama-tokenizer"
-#             )
-#             logger.warning("No tokenizer provided. Using default tokenizer.")
-
-#         self.prompts: List[Dict] = []
-#         if prompts_path is not None and prompts_path.endswith(".jsonl"):
-#             with open(prompts_path, "r", encoding="utf-8") as file:
-#                 for line in file:
-#                     json_line = json.loads(line)
-#                     assert "prompt" in json_line, "The prompt field is required in the JSONL file"
-#                     if "input_tokens" not in json_line:
-#                         json_line["input_tokens"] = self._count_tokens(json_line["prompt"])
-#                     self.prompts.append(json_line)
-#         else:
-#             if not prompts_path:
-#                 prompts_path = Path(__file__).parent / "prompts.txt"  # type: ignore
-#             with open(prompts_path, "r", encoding="utf-8") as file:
-#                 prompt_line = file.readline()
-#                 input_tokens = self._count_tokens(prompt_line)
-#             self.prompts.append({"prompt": prompt_line, "input_tokens": input_tokens})
-#         if json_prompts_path:
-#             self.json_prompts = defaultdict(list)
-#             with open(json_prompts_path, "r", encoding="utf-8") as file:
-#                 for line in file:
-#                     json_line = json.loads(line)
-#                     assert (
-#                         "messages" in json_line
-#                     ), "The messages field is required in the JSONL file."
-#                     assert (
-#                         "response_format" in json_line
-#                     ), "The response_format field is required in the JSONL file."
-#                     self.json_prompts[json.dumps(json_line["response_format"]["schema"])].append(
-#                         json_line["messages"]
-#                     )
-#         else:
-#             self.json_prompts = None
-
-#     def _count_tokens(self, text: str) -> int:
-#         """Get the number of tokens.
-
-#         Parameters
-#         ----------
-#         text : str
-#             The text to tokenize.
-
-#         Returns
-#         -------
-#         output : int
-#             The number of tokens
-#         """
-#         return len(self.tokenizer.encode(text))
-
-#     def generate_prompt(self, params: Dict[str, Any]) -> Dict[str, Any]:
-#         """
-#         Generates a prompt based on the params, e.g. input_tokens, response_format.
-
-#         Parameters
-#         ----------
-#         params : Dict[str, Any]
-#             The desired mean number of tokens in the prompt.
-
-#         Returns
-#         -------
-#         override_params: Dict[str, Any]
-#             The params to override the original request, e.g. messages, response_format.
-#         """
-#         if "response_format" in params:
-#             response_format = params["response_format"]
-#             if response_format.get("type") == "json_object":
-#                 if response_format.get("schema") in self.json_prompts:
-#                     assert len(self.json_prompts[response_format["schema"]]) > 0
-#                     return {"messages":
-#                         random.choice(self.json_prompts[response_format["schema"]])}
-#                 schema, prompts = random.choice(list(self.json_prompts.items()))
-#                 response_format["schema"] = schema
-#                 return {"messages": random.choice(prompts), "response_format": response_format}
-#         tokens_mean = params.get("input_tokens", 128)
-#         assert tokens_mean > 0, "The mean number of tokens must be greater than 0."
-#         remaining_input_tokens = tokens_mean
-#         result_prompt = ""
-#         override_params = None
-#         while remaining_input_tokens > 0:
-#             prompt_dict = random.choice(self.prompts)
-#             cur_input_tokens = prompt_dict["input_tokens"]
-#             cur_prompt = prompt_dict["prompt"]
-#             if override_params is None:
-#                 override_params = prompt_dict["override_params"]
-#             if remaining_input_tokens - cur_input_tokens < 0:
-#                 result_prompt += cur_prompt[:remaining_input_tokens]
-#                 remaining_input_tokens = 0
-#                 break
-#             result_prompt += cur_prompt
-#             remaining_input_tokens -= cur_input_tokens
-#         return {"messages": [{"role": "system", "content": result_prompt}]}
-
-
-# def load_replay_log(log_path: str) -> List[Dict]:
-#     """
-#     Load replay log from file
-
-#     Parameters
-#     ----------
-#     log_path : str
-#         The path to the event log CSV or JSONL file containing the events to replay.
-
-#     Returns
-#     -------
-#     res: List[Dict]
-#         A list of preprocessed event data for replay.
-#     """
-#     if log_path.endswith(".csv"):
-#         import pandas as pd  # pylint: disable=import-outside-toplevel,import-error
-
-#         df = pd.read_csv(log_path)
-#         column_names = df.columns.values
-#         assert (
-#             ("Date" in column_names)
-#             and ("@request" in column_names)
-#             and ("Message" in column_names)
-#         )
-#         df["timestamp"] = pd.to_datetime(df["Date"])
-#         df.sort_values("timestamp", inplace=True)
-#         # Get the request params from the loaded CSV
-#         params = []
-#         for _, row in df.iterrows():
-#             request = row["@request"]
-#             payload = json.loads(str(request))
-#             params.append(
-#                 {
-#                     "timestamp": row["timestamp"],
-#                     "payload": payload,
-#                 }
-#             )
-#         return params
-#     if log_path.endswith(".jsonl"):
-#         with open(log_path, "r", encoding="utf-8") as file:
-#             data = [json.loads(line) for line in file]
-#         for row in data:
-#             row["timestamp"] = datetime.fromisoformat(str(row["timestamp"]))
-#         return data
-#     raise ValueError("Unsupported file format. Please use .csv or .jsonl.")
+class WildChatDataset(Dataset):  # pylint: disable=too-few-public-methods
+    """The dataset class for WildChat dataset."""
+
+    apply_chat_template: bool
+
+    def __init__(self, tokenizer: AutoTokenizer, apply_chat_template: bool) -> None:
+        raw_dataset = load_dataset("allenai/WildChat", split="train")
+        self.tokenizer = tokenizer
+        self.apply_chat_template = apply_chat_template
+
+        # Filter out the conversations with less than 2 turns.
+        _dataset = [
+            (entry["conversation"][0]["content"], entry["conversation"][1]["content"])
+            for entry in raw_dataset
+            if len(entry["conversation"]) >= 2
+            and entry["conversation"][0]["role"] == "user"
+            and entry["conversation"][1]["role"] == "assistant"
+        ]
+
+        prompts = []
+        completions = []
+        for prompt, completion in _dataset:
+            prompts.append(prompt)
+            completions.append(completion)
+        if apply_chat_template:
+            assert (
+                getattr(tokenizer, "chat_template", None) is not None
+            ), '"--apply-chat-template" is set but the tokenizer does not have chat template.'
+            prompts = [
+                tokenizer.apply_chat_template(
+                    [{"role": "user", "content": prompt}],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
+                for prompt in prompts
+            ]
+
+        prompt_token_ids = list(
+            tokenizer(
+                prompts,
+                truncation=True,
+                max_length=min(tokenizer.model_max_length, self.truncate_length),
+                add_special_tokens=False,
+            ).input_ids
+        )
+        completion_token_ids = tokenizer(
+            completions,
+            truncation=True,
+            max_length=min(tokenizer.model_max_length, self.truncate_length),
+            add_special_tokens=False,
+        ).input_ids
+        self._tokenized_dataset: List[Tuple[str, List[int], int]] = []
+        for i in range(len(_dataset)):
+            if len(prompt_token_ids[i]) < 4 or len(completion_token_ids[i]) < 4:
+                # Filter out sequences that are too short
+                continue
+            self._tokenized_dataset.append(
+                (prompts[i], prompt_token_ids[i], len(completion_token_ids[i]))
+            )
+
+    def generate_request_records(  # pylint: disable=too-many-locals
+        self,
+        input_len: Optional[int],
+        output_len: Optional[int],
+        input_len_std: float = 0.0,
+        output_len_std: float = 0.0,
+    ) -> List[RequestRecord]:
+        if self.apply_chat_template:
+            assert (
+                input_len is None
+            ), '"--apply-chat-template" is not supported when "--input-len" is specified.'
+
+        request_records = []
+        for prompt, input_token_ids, output_length in self._tokenized_dataset:
+            input_length = len(input_token_ids)
+            # If the request does not have enough length, discard it.
+            if input_len is not None and input_length < input_len + 4 * input_len_std:
+                continue
+
+            if input_len is not None:
+                input_length = round(
+                    float(np.random.normal(loc=input_len, scale=input_len_std, size=1)[0])
+                )
+                input_token_ids = input_token_ids[:input_length]
+                input_truncated = True
+            else:
+                input_truncated = False
+            if output_len is not None:
+                output_length = round(
+                    float(np.random.normal(loc=output_len, scale=output_len_std, size=1)[0])
+                )
+            elif output_length <= 1:
+                continue
+            request_records.append(
+                RequestRecord(
+                    chat_cmpl=ChatCompletionRequest(
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": (
+                                    self.tokenizer.decode(input_token_ids)
+                                    if input_truncated
+                                    else prompt
+                                ),
+                            }
+                        ],
+                        model="",
+                        max_tokens=output_length,
+                    ),
+                    metrics=Metrics(
+                        success=False,
+                        start_time=0,
+                        finish_time=0,
+                        end_to_end_latency_s=0,
+                        input_tokens=len(input_token_ids),
+                    ),
+                )
+            )
+        return request_records
+
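When explicit lengths are requested, the `generate_request_records` method added above draws a per-request length from a normal distribution and truncates the prompt token ids to it. A toy illustration of that sampling expression, using the same numpy call as the patch (the numbers are made up):

```python
import numpy as np

input_len, input_len_std = 256, 32  # illustrative values for input_len / input_len_std
# Same expression as in generate_request_records above.
sampled_length = round(float(np.random.normal(loc=input_len, scale=input_len_std, size=1)[0]))
print(sampled_length)  # e.g. 247
```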
 
 SUPPORTED_DATASET = [
     "sharegpt",
     "llmperf",
     "json-mode-eval",
     "loogle",
     "react",
+    "wildchat",
 ]
 
 
@@ -811,4 +747,6 @@ def create_dataset(args: argparse.Namespace, tokenizer: AutoTokenizer) -> "Dataset":
             args.apply_chat_template is False
         ), "ReAct dataset does not support applying chat template"
         return ReActDataset(args.dataset_path, tokenizer)
+    if args.dataset == "wildchat":
+        return WildChatDataset(tokenizer, args.apply_chat_template)
     raise ValueError(f"Unrecognized dataset {args.dataset}")
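Putting it together, `create_dataset` now dispatches `--dataset wildchat` to the new class. A rough standalone sketch of driving it directly, assuming a Hugging Face tokenizer (the tokenizer repo below is just the small ungated test tokenizer already referenced elsewhere in this file, and the import path follows the file location in the diff):

```python
from transformers import AutoTokenizer

from mlc_llm.bench.dataset import WildChatDataset  # path per python/mlc_llm/bench/dataset.py

# Any HF tokenizer works; this small ungated one is used purely for illustration.
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
dataset = WildChatDataset(tokenizer, apply_chat_template=False)

# Build request records targeting roughly 256 input and 128 output tokens each.
records = dataset.generate_request_records(input_len=256, output_len=128)
print(len(records), "request records")
```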
