@@ -581,198 +581,134 @@ def generate_request_records(
         return request_records


-# Todo: dataset of log replay # pylint: disable=fixme
-# NOTE: moved from the previous "python/mlc_llm/bench/prompts.py"
-# class PromptsGenerator:  # pylint: disable=too-few-public-methods
-#     """
-#     Generates prompts of a specified token length from a text file containing potential prompts.
-#     """
-
-#     def __init__(
-#         self,
-#         prompts_path: Optional[str] = None,
-#         json_prompts_path: Optional[str] = None,
-#         tokenizer: Optional[Any] = None,
-#         seed: Optional[int] = 11111,
-#     ) -> None:
-#         """
-#         Initializes the PromptsGenerator with the file path and tokenizer.
-
-#         Parameters
-#         ----------
-#         prompts_path : Optional[str]
-#             The path to the file containing the source prompts. This file can be
-#             a plain text file, with each line representing a separate prompt str,
-#             or a .jsonl file where each line is a JSON object formatted as
-#             {"prompt": "prompt text", "input_tokens": 10}.
-
-#         json_prompts_path : Optional[str]
-#             The path to the file containing the source json prompts. This file a
-#             .jsonl file where each line is a JSON object formatted as
-#             {"messages": List[Dict[str, Any]], "response_format": Dict[str, Any]}.
-
-#         tokenizer : Optional[Any]
-#             The tokenizer object to use for tokenizing the prompts.
-
-#         seed : Optional[int]
-#             The seed for the random number generator.
-#         """
-#         random.seed(seed)
-#         self.tokenizer = tokenizer
-#         if not self.tokenizer:
-#             from transformers import (  # pylint: disable=import-outside-toplevel,import-error
-#                 LlamaTokenizerFast,
-#             )
-
-#             self.tokenizer = LlamaTokenizerFast.from_pretrained(
-#                 "hf-internal-testing/llama-tokenizer"
-#             )
-#             logger.warning("No tokenizer provided. Using default tokenizer.")
-
-#         self.prompts: List[Dict] = []
-#         if prompts_path is not None and prompts_path.endswith(".jsonl"):
-#             with open(prompts_path, "r", encoding="utf-8") as file:
-#                 for line in file:
-#                     json_line = json.loads(line)
-#                     assert "prompt" in json_line, "The prompt field is required in the JSONL file"
-#                     if "input_tokens" not in json_line:
-#                         json_line["input_tokens"] = self._count_tokens(json_line["prompt"])
-#                     self.prompts.append(json_line)
-#         else:
-#             if not prompts_path:
-#                 prompts_path = Path(__file__).parent / "prompts.txt"  # type: ignore
-#             with open(prompts_path, "r", encoding="utf-8") as file:
-#                 prompt_line = file.readline()
-#                 input_tokens = self._count_tokens(prompt_line)
-#                 self.prompts.append({"prompt": prompt_line, "input_tokens": input_tokens})
-#         if json_prompts_path:
-#             self.json_prompts = defaultdict(list)
-#             with open(json_prompts_path, "r", encoding="utf-8") as file:
-#                 for line in file:
-#                     json_line = json.loads(line)
-#                     assert (
-#                         "messages" in json_line
-#                     ), "The messages field is required in the JSONL file."
-#                     assert (
-#                         "response_format" in json_line
-#                     ), "The response_format field is required in the JSONL file."
-#                     self.json_prompts[json.dumps(json_line["response_format"]["schema"])].append(
-#                         json_line["messages"]
-#                     )
-#         else:
-#             self.json_prompts = None
-
-#     def _count_tokens(self, text: str) -> int:
-#         """Get the number of tokens.
-
-#         Parameters
-#         ----------
-#         text : str
-#             The text to tokenize.
-
-#         Returns
-#         -------
-#         output : int
-#             The number of tokens
-#         """
-#         return len(self.tokenizer.encode(text))
-
-#     def generate_prompt(self, params: Dict[str, Any]) -> Dict[str, Any]:
-#         """
-#         Generates a prompt based on the params, e.g. input_tokens, response_format.
-
-#         Parameters
-#         ----------
-#         params : Dict[str, Any]
-#             The desired mean number of tokens in the prompt.
-
-#         Returns
-#         -------
-#         override_params: Dict[str, Any]
-#             The params to override the original request, e.g. messages, response_format.
-#         """
-#         if "response_format" in params:
-#             response_format = params["response_format"]
-#             if response_format.get("type") == "json_object":
-#                 if response_format.get("schema") in self.json_prompts:
-#                     assert len(self.json_prompts[response_format["schema"]]) > 0
-#                     return {"messages":
-#                         random.choice(self.json_prompts[response_format["schema"]])}
-#                 schema, prompts = random.choice(list(self.json_prompts.items()))
-#                 response_format["schema"] = schema
-#                 return {"messages": random.choice(prompts), "response_format": response_format}
-#         tokens_mean = params.get("input_tokens", 128)
-#         assert tokens_mean > 0, "The mean number of tokens must be greater than 0."
-#         remaining_input_tokens = tokens_mean
-#         result_prompt = ""
-#         override_params = None
-#         while remaining_input_tokens > 0:
-#             prompt_dict = random.choice(self.prompts)
-#             cur_input_tokens = prompt_dict["input_tokens"]
-#             cur_prompt = prompt_dict["prompt"]
-#             if override_params is None:
-#                 override_params = prompt_dict["override_params"]
-#             if remaining_input_tokens - cur_input_tokens < 0:
-#                 result_prompt += cur_prompt[:remaining_input_tokens]
-#                 remaining_input_tokens = 0
-#                 break
-#             result_prompt += cur_prompt
-#             remaining_input_tokens -= cur_input_tokens
-#         return {"messages": [{"role": "system", "content": result_prompt}]}
-
-
-# def load_replay_log(log_path: str) -> List[Dict]:
-#     """
-#     Load replay log from file
-
-#     Parameters
-#     ----------
-#     log_path : str
-#         The path to the event log CSV or JSONL file containing the events to replay.
-
-#     Returns
-#     -------
-#     res: List[Dict]
-#         A list of preprocessed event data for replay.
-#     """
-#     if log_path.endswith(".csv"):
-#         import pandas as pd  # pylint: disable=import-outside-toplevel,import-error
-
-#         df = pd.read_csv(log_path)
-#         column_names = df.columns.values
-#         assert (
-#             ("Date" in column_names)
-#             and ("@request" in column_names)
-#             and ("Message" in column_names)
-#         )
-#         df["timestamp"] = pd.to_datetime(df["Date"])
-#         df.sort_values("timestamp", inplace=True)
-#         # Get the request params from the loaded CSV
-#         params = []
-#         for _, row in df.iterrows():
-#             request = row["@request"]
-#             payload = json.loads(str(request))
-#             params.append(
-#                 {
-#                     "timestamp": row["timestamp"],
-#                     "payload": payload,
-#                 }
-#             )
-#         return params
-#     if log_path.endswith(".jsonl"):
-#         with open(log_path, "r", encoding="utf-8") as file:
-#             data = [json.loads(line) for line in file]
-#         for row in data:
-#             row["timestamp"] = datetime.fromisoformat(str(row["timestamp"]))
-#         return data
-#     raise ValueError("Unsupported file format. Please use .csv or .jsonl.")
+class WildChatDataset(Dataset):  # pylint: disable=too-few-public-methods
+    """The dataset class for WildChat dataset."""
+
+    apply_chat_template: bool
+
+    def __init__(self, tokenizer: AutoTokenizer, apply_chat_template: bool) -> None:
+        raw_dataset = load_dataset("allenai/WildChat", split="train")
+        self.tokenizer = tokenizer
+        self.apply_chat_template = apply_chat_template
+
+        # Filter out the conversations with less than 2 turns.
+        _dataset = [
+            (entry["conversation"][0]["content"], entry["conversation"][1]["content"])
+            for entry in raw_dataset
+            if len(entry["conversation"]) >= 2
+            and entry["conversation"][0]["role"] == "user"
+            and entry["conversation"][1]["role"] == "assistant"
+        ]
+
+        prompts = []
+        completions = []
+        for prompt, completion in _dataset:
+            prompts.append(prompt)
+            completions.append(completion)
+        if apply_chat_template:
+            assert (
+                getattr(tokenizer, "chat_template", None) is not None
+            ), '"--apply-chat-template" is set but the tokenizer does not have chat template.'
+            prompts = [
+                tokenizer.apply_chat_template(
+                    [{"role": "user", "content": prompt}],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
+                for prompt in prompts
+            ]
+
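+        # Tokenize prompts and completions up front so that sequence lengths are known
+        # for the length-based filtering below.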
+        prompt_token_ids = list(
+            tokenizer(
+                prompts,
+                truncation=True,
+                max_length=min(tokenizer.model_max_length, self.truncate_length),
+                add_special_tokens=False,
+            ).input_ids
+        )
+        completion_token_ids = tokenizer(
+            completions,
+            truncation=True,
+            max_length=min(tokenizer.model_max_length, self.truncate_length),
+            add_special_tokens=False,
+        ).input_ids
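+        # Keep (prompt text, prompt token ids, completion length) triples; the completion
+        # length later becomes the per-request max_tokens.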
+        self._tokenized_dataset: List[Tuple[str, List[int], int]] = []
+        for i in range(len(_dataset)):
+            if len(prompt_token_ids[i]) < 4 or len(completion_token_ids[i]) < 4:
+                # Filter out sequences that are too short
+                continue
+            self._tokenized_dataset.append(
+                (prompts[i], prompt_token_ids[i], len(completion_token_ids[i]))
+            )
+
+    def generate_request_records(  # pylint: disable=too-many-locals
+        self,
+        input_len: Optional[int],
+        output_len: Optional[int],
+        input_len_std: float = 0.0,
+        output_len_std: float = 0.0,
+    ) -> List[RequestRecord]:
+        if self.apply_chat_template:
+            assert (
+                input_len is None
+            ), '"--apply-chat-template" is not supported when "--input-len" is specified.'
+
+        request_records = []
+        for prompt, input_token_ids, output_length in self._tokenized_dataset:
+            input_length = len(input_token_ids)
+            # If the request does not have enough length, discard it.
+            if input_len is not None and input_length < input_len + 4 * input_len_std:
+                continue
+
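+            # When --input-len / --output-len are given, resample the target lengths from a
+            # normal distribution and truncate the prompt token ids accordingly.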
+            if input_len is not None:
+                input_length = round(
+                    float(np.random.normal(loc=input_len, scale=input_len_std, size=1)[0])
+                )
+                input_token_ids = input_token_ids[:input_length]
+                input_truncated = True
+            else:
+                input_truncated = False
+            if output_len is not None:
+                output_length = round(
+                    float(np.random.normal(loc=output_len, scale=output_len_std, size=1)[0])
+                )
+            elif output_length <= 1:
+                continue
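+            # Decode the truncated token ids back to text only when truncation happened;
+            # otherwise send the original prompt unchanged.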
+            request_records.append(
+                RequestRecord(
+                    chat_cmpl=ChatCompletionRequest(
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": (
+                                    self.tokenizer.decode(input_token_ids)
+                                    if input_truncated
+                                    else prompt
+                                ),
+                            }
+                        ],
+                        model="",
+                        max_tokens=output_length,
+                    ),
+                    metrics=Metrics(
+                        success=False,
+                        start_time=0,
+                        finish_time=0,
+                        end_to_end_latency_s=0,
+                        input_tokens=len(input_token_ids),
+                    ),
+                )
+            )
+        return request_records
+


 SUPPORTED_DATASET = [
     "sharegpt",
     "llmperf",
     "json-mode-eval",
     "loogle",
     "react",
+    "wildchat",
 ]


@@ -811,4 +747,6 @@ def create_dataset(args: argparse.Namespace, tokenizer: AutoTokenizer) -> "Datas
             args.apply_chat_template is False
         ), "ReAct dataset does not support applying chat template"
         return ReActDataset(args.dataset_path, tokenizer)
+    if args.dataset == "wildchat":
+        return WildChatDataset(tokenizer, args.apply_chat_template)
     raise ValueError(f"Unrecognized dataset {args.dataset}")