
[Bench] Add WildChat dataset (#3102)
This PR introduces the [WildChat dataset](https://huggingface.co/datasets/allenai/WildChat)
to `mlc_llm.bench`.
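
For context, here is a minimal usage sketch (an editor's illustration, not part of the commit): it wires the new dataset through `create_dataset` using only the flags visible in this diff, with the tokenizer checkpoint chosen arbitrarily.

```python
# Hypothetical sketch; assumes network access to the Hugging Face hub.
# Note that --dataset-path is no longer required (see the __main__.py change
# below), since WildChat is downloaded directly from the hub.
import argparse

from transformers import AutoTokenizer

from mlc_llm.bench.dataset import create_dataset

args = argparse.Namespace(dataset="wildchat", dataset_path=None, apply_chat_template=False)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
dataset = create_dataset(args, tokenizer)  # dispatches to WildChatDataset
records = dataset.generate_request_records(input_len=None, output_len=256)
print(f"Generated {len(records)} request records")
```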
MasterJH5574 authored Jan 21, 2025
1 parent c84e80a commit 03509ce
Showing 2 changed files with 123 additions and 186 deletions.
1 change: 0 additions & 1 deletion python/mlc_llm/bench/__main__.py
@@ -187,7 +187,6 @@ def _main():
     parser.add_argument(
         "--dataset-path",
         type=str,
-        required=True,
         help="The dataset file path.",
     )
     parser.add_argument(
308 changes: 123 additions & 185 deletions python/mlc_llm/bench/dataset.py
@@ -581,198 +581,134 @@ def generate_request_records(
         return request_records


-# Todo: dataset of log replay  # pylint: disable=fixme
-# NOTE: moved from the previous "python/mlc_llm/bench/prompts.py"
-# class PromptsGenerator:  # pylint: disable=too-few-public-methods
-#     """
-#     Generates prompts of a specified token length from a text file containing potential prompts.
-#     """
-
-#     def __init__(
-#         self,
-#         prompts_path: Optional[str] = None,
-#         json_prompts_path: Optional[str] = None,
-#         tokenizer: Optional[Any] = None,
-#         seed: Optional[int] = 11111,
-#     ) -> None:
-#         """
-#         Initializes the PromptsGenerator with the file path and tokenizer.
-
-#         Parameters
-#         ----------
-#         prompts_path : Optional[str]
-#             The path to the file containing the source prompts. This file can be
-#             a plain text file, with each line representing a separate prompt str,
-#             or a .jsonl file where each line is a JSON object formatted as
-#             {"prompt": "prompt text", "input_tokens": 10}.
-
-#         json_prompts_path : Optional[str]
-#             The path to the file containing the source json prompts. This file a
-#             .jsonl file where each line is a JSON object formatted as
-#             {"messages": List[Dict[str, Any]], "response_format": Dict[str, Any]}.
-
-#         tokenizer : Optional[Any]
-#             The tokenizer object to use for tokenizing the prompts.
-
-#         seed : Optional[int]
-#             The seed for the random number generator.
-#         """
-#         random.seed(seed)
-#         self.tokenizer = tokenizer
-#         if not self.tokenizer:
-#             from transformers import (  # pylint: disable=import-outside-toplevel,import-error
-#                 LlamaTokenizerFast,
-#             )
-
-#             self.tokenizer = LlamaTokenizerFast.from_pretrained(
-#                 "hf-internal-testing/llama-tokenizer"
-#             )
-#             logger.warning("No tokenizer provided. Using default tokenizer.")
-
-#         self.prompts: List[Dict] = []
-#         if prompts_path is not None and prompts_path.endswith(".jsonl"):
-#             with open(prompts_path, "r", encoding="utf-8") as file:
-#                 for line in file:
-#                     json_line = json.loads(line)
-#                     assert "prompt" in json_line, "The prompt field is required in the JSONL file"
-#                     if "input_tokens" not in json_line:
-#                         json_line["input_tokens"] = self._count_tokens(json_line["prompt"])
-#                     self.prompts.append(json_line)
-#         else:
-#             if not prompts_path:
-#                 prompts_path = Path(__file__).parent / "prompts.txt"  # type: ignore
-#             with open(prompts_path, "r", encoding="utf-8") as file:
-#                 prompt_line = file.readline()
-#                 input_tokens = self._count_tokens(prompt_line)
-#                 self.prompts.append({"prompt": prompt_line, "input_tokens": input_tokens})
-#         if json_prompts_path:
-#             self.json_prompts = defaultdict(list)
-#             with open(json_prompts_path, "r", encoding="utf-8") as file:
-#                 for line in file:
-#                     json_line = json.loads(line)
-#                     assert (
-#                         "messages" in json_line
-#                     ), "The messages field is required in the JSONL file."
-#                     assert (
-#                         "response_format" in json_line
-#                     ), "The response_format field is required in the JSONL file."
-#                     self.json_prompts[json.dumps(json_line["response_format"]["schema"])].append(
-#                         json_line["messages"]
-#                     )
-#         else:
-#             self.json_prompts = None
-
-#     def _count_tokens(self, text: str) -> int:
-#         """Get the number of tokens.
-
-#         Parameters
-#         ----------
-#         text : str
-#             The text to tokenize.
-
-#         Returns
-#         -------
-#         output : int
-#             The number of tokens
-#         """
-#         return len(self.tokenizer.encode(text))
-
-#     def generate_prompt(self, params: Dict[str, Any]) -> Dict[str, Any]:
-#         """
-#         Generates a prompt based on the params, e.g. input_tokens, response_format.
-
-#         Parameters
-#         ----------
-#         params : Dict[str, Any]
-#             The desired mean number of tokens in the prompt.
-
-#         Returns
-#         -------
-#         override_params: Dict[str, Any]
-#             The params to override the original request, e.g. messages, response_format.
-#         """
-#         if "response_format" in params:
-#             response_format = params["response_format"]
-#             if response_format.get("type") == "json_object":
-#                 if response_format.get("schema") in self.json_prompts:
-#                     assert len(self.json_prompts[response_format["schema"]]) > 0
-#                     return {"messages":
-#                             random.choice(self.json_prompts[response_format["schema"]])}
-#                 schema, prompts = random.choice(list(self.json_prompts.items()))
-#                 response_format["schema"] = schema
-#                 return {"messages": random.choice(prompts), "response_format": response_format}
-#         tokens_mean = params.get("input_tokens", 128)
-#         assert tokens_mean > 0, "The mean number of tokens must be greater than 0."
-#         remaining_input_tokens = tokens_mean
-#         result_prompt = ""
-#         override_params = None
-#         while remaining_input_tokens > 0:
-#             prompt_dict = random.choice(self.prompts)
-#             cur_input_tokens = prompt_dict["input_tokens"]
-#             cur_prompt = prompt_dict["prompt"]
-#             if override_params is None:
-#                 override_params = prompt_dict["override_params"]
-#             if remaining_input_tokens - cur_input_tokens < 0:
-#                 result_prompt += cur_prompt[:remaining_input_tokens]
-#                 remaining_input_tokens = 0
-#                 break
-#             result_prompt += cur_prompt
-#             remaining_input_tokens -= cur_input_tokens
-#         return {"messages": [{"role": "system", "content": result_prompt}]}
-
-
-# def load_replay_log(log_path: str) -> List[Dict]:
-#     """
-#     Load replay log from file
-
-#     Parameters
-#     ----------
-#     log_path : str
-#         The path to the event log CSV or JSONL file containing the events to replay.
-
-#     Returns
-#     -------
-#     res: List[Dict]
-#         A list of preprocessed event data for replay.
-#     """
-#     if log_path.endswith(".csv"):
-#         import pandas as pd  # pylint: disable=import-outside-toplevel,import-error
-
-#         df = pd.read_csv(log_path)
-#         column_names = df.columns.values
-#         assert (
-#             ("Date" in column_names)
-#             and ("@request" in column_names)
-#             and ("Message" in column_names)
-#         )
-#         df["timestamp"] = pd.to_datetime(df["Date"])
-#         df.sort_values("timestamp", inplace=True)
-#         # Get the request params from the loaded CSV
-#         params = []
-#         for _, row in df.iterrows():
-#             request = row["@request"]
-#             payload = json.loads(str(request))
-#             params.append(
-#                 {
-#                     "timestamp": row["timestamp"],
-#                     "payload": payload,
-#                 }
-#             )
-#         return params
-#     if log_path.endswith(".jsonl"):
-#         with open(log_path, "r", encoding="utf-8") as file:
-#             data = [json.loads(line) for line in file]
-#             for row in data:
-#                 row["timestamp"] = datetime.fromisoformat(str(row["timestamp"]))
-#             return data
-#     raise ValueError("Unsupported file format. Please use .csv or .jsonl.")
+class WildChatDataset(Dataset):  # pylint: disable=too-few-public-methods
+    """The dataset class for WildChat dataset."""
+
+    apply_chat_template: bool
+
+    def __init__(self, tokenizer: AutoTokenizer, apply_chat_template: bool) -> None:
+        raw_dataset = load_dataset("allenai/WildChat", split="train")
+        self.tokenizer = tokenizer
+        self.apply_chat_template = apply_chat_template
+
+        # Filter out the conversations with less than 2 turns.
+        _dataset = [
+            (entry["conversation"][0]["content"], entry["conversation"][1]["content"])
+            for entry in raw_dataset
+            if len(entry["conversation"]) >= 2
+            and entry["conversation"][0]["role"] == "user"
+            and entry["conversation"][1]["role"] == "assistant"
+        ]
+
+        prompts = []
+        completions = []
+        for prompt, completion in _dataset:
+            prompts.append(prompt)
+            completions.append(completion)
+        if apply_chat_template:
+            assert (
+                getattr(tokenizer, "chat_template", None) is not None
+            ), '"--apply-chat-template" is set but the tokenizer does not have chat template.'
+            prompts = [
+                tokenizer.apply_chat_template(
+                    [{"role": "user", "content": prompt}],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
+                for prompt in prompts
+            ]
+
+        prompt_token_ids = list(
+            tokenizer(
+                prompts,
+                truncation=True,
+                max_length=min(tokenizer.model_max_length, self.truncate_length),
+                add_special_tokens=False,
+            ).input_ids
+        )
+        completion_token_ids = tokenizer(
+            completions,
+            truncation=True,
+            max_length=min(tokenizer.model_max_length, self.truncate_length),
+            add_special_tokens=False,
+        ).input_ids
+        self._tokenized_dataset: List[Tuple[str, List[int], int]] = []
+        for i in range(len(_dataset)):
+            if len(prompt_token_ids[i]) < 4 or len(completion_token_ids[i]) < 4:
+                # Filter out sequences that are too short
+                continue
+            self._tokenized_dataset.append(
+                (prompts[i], prompt_token_ids[i], len(completion_token_ids[i]))
+            )

+    def generate_request_records(  # pylint: disable=too-many-locals
+        self,
+        input_len: Optional[int],
+        output_len: Optional[int],
+        input_len_std: float = 0.0,
+        output_len_std: float = 0.0,
+    ) -> List[RequestRecord]:
+        if self.apply_chat_template:
+            assert (
+                input_len is None
+            ), '"--apply-chat-template" is not supported when "--input-len" is specified.'
+
+        request_records = []
+        for prompt, input_token_ids, output_length in self._tokenized_dataset:
+            input_length = len(input_token_ids)
+            # If the request does not have enough length, discard it.
+            if input_len is not None and input_length < input_len + 4 * input_len_std:
+                continue
+
+            if input_len is not None:
+                input_length = round(
+                    float(np.random.normal(loc=input_len, scale=input_len_std, size=1)[0])
+                )
+                input_token_ids = input_token_ids[:input_length]
+                input_truncated = True
+            else:
+                input_truncated = False
+            if output_len is not None:
+                output_length = round(
+                    float(np.random.normal(loc=output_len, scale=output_len_std, size=1)[0])
+                )
+            elif output_length <= 1:
+                continue
+            request_records.append(
+                RequestRecord(
+                    chat_cmpl=ChatCompletionRequest(
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": (
+                                    self.tokenizer.decode(input_token_ids)
+                                    if input_truncated
+                                    else prompt
+                                ),
+                            }
+                        ],
+                        model="",
+                        max_tokens=output_length,
+                    ),
+                    metrics=Metrics(
+                        success=False,
+                        start_time=0,
+                        finish_time=0,
+                        end_to_end_latency_s=0,
+                        input_tokens=len(input_token_ids),
+                    ),
+                )
+            )
+        return request_records


 SUPPORTED_DATASET = [
     "sharegpt",
     "llmperf",
     "json-mode-eval",
     "loogle",
     "react",
+    "wildchat",
 ]


@@ -811,4 +747,6 @@ def create_dataset(args: argparse.Namespace, tokenizer: AutoTokenizer) -> "Dataset":
             args.apply_chat_template is False
         ), "ReAct dataset does not support applying chat template"
         return ReActDataset(args.dataset_path, tokenizer)
+    if args.dataset == "wildchat":
+        return WildChatDataset(tokenizer, args.apply_chat_template)
     raise ValueError(f"Unrecognized dataset {args.dataset}")
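
To illustrate what the new dataset class consumes, here is a standalone sketch of the filtering step performed in `WildChatDataset.__init__` (an illustration assuming the `datasets` package and hub access are available, not code from this commit):

```python
# Keep only conversations whose first two turns form a (user, assistant) pair,
# mirroring the list comprehension in WildChatDataset.__init__ above.
from datasets import load_dataset

raw_dataset = load_dataset("allenai/WildChat", split="train")
pairs = [
    (entry["conversation"][0]["content"], entry["conversation"][1]["content"])
    for entry in raw_dataset
    if len(entry["conversation"]) >= 2
    and entry["conversation"][0]["role"] == "user"
    and entry["conversation"][1]["role"] == "assistant"
]
print(f"Kept {len(pairs)} prompt/completion pairs")
```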
