@@ -581,198 +581,134 @@ def generate_request_records(
         return request_records
-# Todo: dataset of log replay  # pylint: disable=fixme
-# NOTE: moved from the previous "python/mlc_llm/bench/prompts.py"
-# class PromptsGenerator:  # pylint: disable=too-few-public-methods
-#     """
-#     Generates prompts of a specified token length from a text file containing potential prompts.
-#     """
-
-#     def __init__(
-#         self,
-#         prompts_path: Optional[str] = None,
-#         json_prompts_path: Optional[str] = None,
-#         tokenizer: Optional[Any] = None,
-#         seed: Optional[int] = 11111,
-#     ) -> None:
-#         """
-#         Initializes the PromptsGenerator with the file path and tokenizer.
-
-#         Parameters
-#         ----------
-#         prompts_path : Optional[str]
-#             The path to the file containing the source prompts. This file can be
-#             a plain text file, with each line representing a separate prompt str,
-#             or a .jsonl file where each line is a JSON object formatted as
-#             {"prompt": "prompt text", "input_tokens": 10}.
-
-#         json_prompts_path : Optional[str]
-#             The path to the file containing the source json prompts. This file a
-#             .jsonl file where each line is a JSON object formatted as
-#             {"messages": List[Dict[str, Any]], "response_format": Dict[str, Any]}.
-
-#         tokenizer : Optional[Any]
-#             The tokenizer object to use for tokenizing the prompts.
-
-#         seed : Optional[int]
-#             The seed for the random number generator.
-#         """
-#         random.seed(seed)
-#         self.tokenizer = tokenizer
-#         if not self.tokenizer:
-#             from transformers import (  # pylint: disable=import-outside-toplevel,import-error
-#                 LlamaTokenizerFast,
-#             )
-
-#             self.tokenizer = LlamaTokenizerFast.from_pretrained(
-#                 "hf-internal-testing/llama-tokenizer"
-#             )
-#             logger.warning("No tokenizer provided. Using default tokenizer.")
-
-#         self.prompts: List[Dict] = []
-#         if prompts_path is not None and prompts_path.endswith(".jsonl"):
-#             with open(prompts_path, "r", encoding="utf-8") as file:
-#                 for line in file:
-#                     json_line = json.loads(line)
-#                     assert "prompt" in json_line, "The prompt field is required in the JSONL file"
-#                     if "input_tokens" not in json_line:
-#                         json_line["input_tokens"] = self._count_tokens(json_line["prompt"])
-#                     self.prompts.append(json_line)
-#         else:
-#             if not prompts_path:
-#                 prompts_path = Path(__file__).parent / "prompts.txt"  # type: ignore
-#             with open(prompts_path, "r", encoding="utf-8") as file:
-#                 prompt_line = file.readline()
-#                 input_tokens = self._count_tokens(prompt_line)
-#                 self.prompts.append({"prompt": prompt_line, "input_tokens": input_tokens})
-#         if json_prompts_path:
-#             self.json_prompts = defaultdict(list)
-#             with open(json_prompts_path, "r", encoding="utf-8") as file:
-#                 for line in file:
-#                     json_line = json.loads(line)
-#                     assert (
-#                         "messages" in json_line
-#                     ), "The messages field is required in the JSONL file."
-#                     assert (
-#                         "response_format" in json_line
-#                     ), "The response_format field is required in the JSONL file."
-#                     self.json_prompts[json.dumps(json_line["response_format"]["schema"])].append(
-#                         json_line["messages"]
-#                     )
-#         else:
-#             self.json_prompts = None
-
-#     def _count_tokens(self, text: str) -> int:
-#         """Get the number of tokens.
-
-#         Parameters
-#         ----------
-#         text : str
-#             The text to tokenize.
-
-#         Returns
-#         -------
-#         output : int
-#             The number of tokens
-#         """
-#         return len(self.tokenizer.encode(text))
-
-#     def generate_prompt(self, params: Dict[str, Any]) -> Dict[str, Any]:
-#         """
-#         Generates a prompt based on the params, e.g. input_tokens, response_format.
-
-#         Parameters
-#         ----------
-#         params : Dict[str, Any]
-#             The desired mean number of tokens in the prompt.
-
-#         Returns
-#         -------
-#         override_params: Dict[str, Any]
-#             The params to override the original request, e.g. messages, response_format.
-#         """
-#         if "response_format" in params:
-#             response_format = params["response_format"]
-#             if response_format.get("type") == "json_object":
-#                 if response_format.get("schema") in self.json_prompts:
-#                     assert len(self.json_prompts[response_format["schema"]]) > 0
-#                     return {"messages":
-#                         random.choice(self.json_prompts[response_format["schema"]])}
-#                 schema, prompts = random.choice(list(self.json_prompts.items()))
-#                 response_format["schema"] = schema
-#                 return {"messages": random.choice(prompts), "response_format": response_format}
-#         tokens_mean = params.get("input_tokens", 128)
-#         assert tokens_mean > 0, "The mean number of tokens must be greater than 0."
-#         remaining_input_tokens = tokens_mean
-#         result_prompt = ""
-#         override_params = None
-#         while remaining_input_tokens > 0:
-#             prompt_dict = random.choice(self.prompts)
-#             cur_input_tokens = prompt_dict["input_tokens"]
-#             cur_prompt = prompt_dict["prompt"]
-#             if override_params is None:
-#                 override_params = prompt_dict["override_params"]
-#             if remaining_input_tokens - cur_input_tokens < 0:
-#                 result_prompt += cur_prompt[:remaining_input_tokens]
-#                 remaining_input_tokens = 0
-#                 break
-#             result_prompt += cur_prompt
-#             remaining_input_tokens -= cur_input_tokens
-#         return {"messages": [{"role": "system", "content": result_prompt}]}
-
-
-# def load_replay_log(log_path: str) -> List[Dict]:
-#     """
-#     Load replay log from file
-
-#     Parameters
-#     ----------
-#     log_path : str
-#         The path to the event log CSV or JSONL file containing the events to replay.
-
-#     Returns
-#     -------
-#     res: List[Dict]
-#         A list of preprocessed event data for replay.
-#     """
-#     if log_path.endswith(".csv"):
-#         import pandas as pd  # pylint: disable=import-outside-toplevel,import-error
-
-#         df = pd.read_csv(log_path)
-#         column_names = df.columns.values
-#         assert (
-#             ("Date" in column_names)
-#             and ("@request" in column_names)
-#             and ("Message" in column_names)
-#         )
-#         df["timestamp"] = pd.to_datetime(df["Date"])
-#         df.sort_values("timestamp", inplace=True)
-#         # Get the request params from the loaded CSV
-#         params = []
-#         for _, row in df.iterrows():
-#             request = row["@request"]
-#             payload = json.loads(str(request))
-#             params.append(
-#                 {
-#                     "timestamp": row["timestamp"],
-#                     "payload": payload,
-#                 }
-#             )
-#         return params
-#     if log_path.endswith(".jsonl"):
-#         with open(log_path, "r", encoding="utf-8") as file:
-#             data = [json.loads(line) for line in file]
-#             for row in data:
-#                 row["timestamp"] = datetime.fromisoformat(str(row["timestamp"]))
-#             return data
-#     raise ValueError("Unsupported file format. Please use .csv or .jsonl.")
+class WildChatDataset(Dataset):  # pylint: disable=too-few-public-methods
+    """The dataset class for WildChat dataset."""
+
+    apply_chat_template: bool
+
+    def __init__(self, tokenizer: AutoTokenizer, apply_chat_template: bool) -> None:
+        raw_dataset = load_dataset("allenai/WildChat", split="train")
+        self.tokenizer = tokenizer
+        self.apply_chat_template = apply_chat_template
+
+        # Filter out the conversations with less than 2 turns.
+        _dataset = [
+            (entry["conversation"][0]["content"], entry["conversation"][1]["content"])
+            for entry in raw_dataset
+            if len(entry["conversation"]) >= 2
+            and entry["conversation"][0]["role"] == "user"
+            and entry["conversation"][1]["role"] == "assistant"
+        ]
+
+        prompts = []
+        completions = []
+        for prompt, completion in _dataset:
+            prompts.append(prompt)
+            completions.append(completion)
+        if apply_chat_template:
+            assert (
+                getattr(tokenizer, "chat_template", None) is not None
+            ), '"--apply-chat-template" is set but the tokenizer does not have chat template.'
+            prompts = [
+                tokenizer.apply_chat_template(
+                    [{"role": "user", "content": prompt}],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
+                for prompt in prompts
+            ]
+
+        prompt_token_ids = list(
+            tokenizer(
+                prompts,
+                truncation=True,
+                max_length=min(tokenizer.model_max_length, self.truncate_length),
+                add_special_tokens=False,
+            ).input_ids
+        )
+        completion_token_ids = tokenizer(
+            completions,
+            truncation=True,
+            max_length=min(tokenizer.model_max_length, self.truncate_length),
+            add_special_tokens=False,
+        ).input_ids
+        self._tokenized_dataset: List[Tuple[str, List[int], int]] = []
+        for i in range(len(_dataset)):
+            if len(prompt_token_ids[i]) < 4 or len(completion_token_ids[i]) < 4:
+                # Filter out sequences that are too short
+                continue
+            self._tokenized_dataset.append(
+                (prompts[i], prompt_token_ids[i], len(completion_token_ids[i]))
+            )
+
+    def generate_request_records(  # pylint: disable=too-many-locals
+        self,
+        input_len: Optional[int],
+        output_len: Optional[int],
+        input_len_std: float = 0.0,
+        output_len_std: float = 0.0,
+    ) -> List[RequestRecord]:
+        if self.apply_chat_template:
+            assert (
+                input_len is None
+            ), '"--apply-chat-template" is not supported when "--input-len" is specified.'
+
+        request_records = []
+        for prompt, input_token_ids, output_length in self._tokenized_dataset:
+            input_length = len(input_token_ids)
+            # If the request does not have enough length, discard it.
+            if input_len is not None and input_length < input_len + 4 * input_len_std:
+                continue
+
+            if input_len is not None:
+                input_length = round(
+                    float(np.random.normal(loc=input_len, scale=input_len_std, size=1)[0])
+                )
+                input_token_ids = input_token_ids[:input_length]
+                input_truncated = True
+            else:
+                input_truncated = False
+            if output_len is not None:
+                output_length = round(
+                    float(np.random.normal(loc=output_len, scale=output_len_std, size=1)[0])
+                )
+            elif output_length <= 1:
+                continue
+            request_records.append(
+                RequestRecord(
+                    chat_cmpl=ChatCompletionRequest(
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": (
+                                    self.tokenizer.decode(input_token_ids)
+                                    if input_truncated
+                                    else prompt
+                                ),
+                            }
+                        ],
+                        model="",
+                        max_tokens=output_length,
+                    ),
+                    metrics=Metrics(
+                        success=False,
+                        start_time=0,
+                        finish_time=0,
+                        end_to_end_latency_s=0,
+                        input_tokens=len(input_token_ids),
+                    ),
+                )
+            )
+        return request_records
+
 
 
 SUPPORTED_DATASET = [
     "sharegpt",
     "llmperf",
     "json-mode-eval",
     "loogle",
     "react",
+    "wildchat",
 ]
 
@@ -811,4 +747,6 @@ def create_dataset(args: argparse.Namespace, tokenizer: AutoTokenizer) -> "Datas
             args.apply_chat_template is False
         ), "ReAct dataset does not support applying chat template"
         return ReActDataset(args.dataset_path, tokenizer)
+    if args.dataset == "wildchat":
+        return WildChatDataset(tokenizer, args.apply_chat_template)
     raise ValueError(f"Unrecognized dataset {args.dataset}")
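For reference, a minimal usage sketch of the new class outside the bench CLI. This is not part of the diff: the tokenizer checkpoint below is only an example (it is the same default used by the removed PromptsGenerator), and loading "allenai/WildChat" requires the dataset to be available locally or downloadable.

# Hypothetical usage sketch of WildChatDataset; tokenizer name is illustrative.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
dataset = WildChatDataset(tokenizer, apply_chat_template=False)

# Keep the dataset's own prompt/output lengths (no resampling of lengths).
records = dataset.generate_request_records(input_len=None, output_len=None)
print(len(records), "request records generated")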