diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index f17b2270..8e05471b 100644 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -14,7 +14,6 @@ # ============================================================================ import os from typing import Tuple, List - import sys sys.stdout = open(sys.stdout.fileno(), mode="w", buffering=1) @@ -88,7 +87,8 @@ def boolean_string(string): parser.add_argument("--prompt_path", type=str, default="prompt.json", help="Path to model file") parser.add_argument("--padding", help="Enable padding, Default to True.", type=boolean_string, default=True) parser.add_argument("--csv", type=str, default="", help="Path to csv file") - +parser.add_argument("--sonnet_prefix_len", type=int, default=200, help="sonnet dataset Prefix length") +parser.add_argument("--sonnet_count", type=int, default=20, help="sonnet dataset prompt count") def build_inputs_chatglm(tokenizer, query: List[str], padding, history: List[Tuple[str, str]] = []): prompts = [] @@ -126,11 +126,97 @@ def get_inputs(args, prompt_pool, tokenizer): max_lens.append(max_len) return inputs, max_lens + + +class SonnetDataset: + """ + Copy from vllm + Simplified implementation of the Sonnet dataset. Loads poem lines from a + text file and generates sample requests. Default values here copied from + `benchmark_serving.py` for the sonnet dataset. + """ + + DEFAULT_PREFIX_LEN = 200 + DEFAULT_INPUT_LEN = 550 + DEFAULT_OUTPUT_LEN = 150 + + def __init__( + self, + dataset_path: str = None, + random_seed: int = 47 + ) -> None: + self.dataset_path = dataset_path + self.random_seed = random_seed + self.load_data() + + def load_data(self) -> None: + if not self.dataset_path: + raise ValueError("dataset_path must be provided.") + with open(self.dataset_path, encoding="utf-8") as f: + self.data = f.readlines() + + def sample( + self, + tokenizer, + num_requests: int, + prefix_len: int = DEFAULT_PREFIX_LEN, + input_len: int = DEFAULT_INPUT_LEN, + output_len: int = DEFAULT_OUTPUT_LEN, + return_prompt_formatted: bool = False, + **kwargs, + ) -> list: + # Calculate average token length for a poem line. + tokenized_lines = [tokenizer(line).input_ids for line in self.data] + avg_len = sum(len(tokens) + for tokens in tokenized_lines) / len(tokenized_lines) + + # Build the base prompt. + base_prompt = "Pick as many lines as you can from these poem lines:\n" + base_msg = [{"role": "user", "content": base_prompt}] + base_fmt = tokenizer.apply_chat_template(base_msg, + add_generation_prompt=True, + tokenize=False) + base_offset = len(tokenizer(base_fmt).input_ids) + if input_len <= base_offset: + raise ValueError( + f"'input_len' must be higher than the base prompt length " + f"({base_offset}).") + + # Determine how many poem lines to use. + num_input_lines = round((input_len - base_offset) / avg_len) + num_prefix_lines = round((prefix_len - base_offset) / avg_len) + prefix_lines = self.data[:num_prefix_lines] + + samples = [] + for _ in range(num_requests): + extra_lines = random.choices(self.data, + k=num_input_lines - num_prefix_lines) + prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}" + msg = [{"role": "user", "content": prompt}] + prompt_formatted = tokenizer.apply_chat_template( + msg, add_generation_prompt=True, tokenize=False) + # prompt_len = len(tokenizer(prompt_formatted).input_ids) + samples.append(prompt_formatted) + return samples + + if __name__ == "__main__": args = parser.parse_args() + tokenizer = AutoTokenizer.from_pretrained( + args.token_path, use_fast=False, padding_side="left", trust_remote_code=True, legacy=False + ) if args.prompt_path.endswith('.txt'): with open(args.prompt_path, "r") as txt_file: prompt_pool = txt_file.read().splitlines() + elif args.prompt_path.endswith('.sonnet'): + sonnet = SonnetDataset(dataset_path=args.prompt_path) + prompt_pool = sonnet.sample( + tokenizer=tokenizer, + num_requests=int(args.sonnet_count), + prefix_len=int(args.sonnet_prefix_len), + input_len=int(args.token_in), + output_len=int(args.token_out), + ) elif args.prompt_path.endswith('.json'): with open(args.prompt_path, "r") as json_file: prompt_pool = json.load(json_file) @@ -178,10 +264,6 @@ def get_inputs(args, prompt_pool, tokenizer): if "deepseek" in args.model_name.lower(): prompt_pool = prompt_pool["deepseek"] - tokenizer = AutoTokenizer.from_pretrained( - args.token_path, use_fast=False, padding_side="left", trust_remote_code=True, legacy=False - ) - try: import xfastertransformer diff --git a/benchmark/run_benchmark.sh b/benchmark/run_benchmark.sh index 1a172d63..8d6b9ae6 100755 --- a/benchmark/run_benchmark.sh +++ b/benchmark/run_benchmark.sh @@ -116,6 +116,14 @@ while [ -n "$1" ]; do prompt_path=$2 shift 2 ;; + -splen | --sonnet_prefix_len) + sonnet_prefix_len=$2 + shift 2 + ;; + -sc | --sonnet_count) + sonnet_count=$2 + shift 2 + ;; "") shift break @@ -174,6 +182,14 @@ benchmark_cmd="python "${SCRIPT_DIR}"/benchmark.py \ --iteration ${iter} \ --warmup ${warmup}" +if [ -n $sonnet_prefix_len ]; then + benchmark_cmd+=" --sonnet_prefix_len ${sonnet_prefix_len}" +fi + +if [ -n $sonnet_count ]; then + benchmark_cmd+=" --sonnet_count ${sonnet_count}" +fi + if [ -n $csv ]; then benchmark_cmd+=" --csv=$csv" fi