diff --git a/Dockerfile b/Dockerfile index c658cfc..0ddd15a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.9.20-slim-bookworm as dev +FROM python:3.9.20-slim-bookworm AS dev RUN apt-get update -y \ && apt-get install -y python3-pip git vim curl wget diff --git a/benchmark_serving.py b/benchmark_serving.py index 7919c89..3b18f56 100644 --- a/benchmark_serving.py +++ b/benchmark_serving.py @@ -42,6 +42,201 @@ from google.protobuf.timestamp_pb2 import Timestamp + + +import os +import sys +import uuid +import traceback +from google.cloud import spanner +import math +from google.api_core import exceptions as gcp_exceptions + +def safe_json_value(value, default=0.0): + """Convert value to JSON-safe format, handling NaN and Infinity.""" + if value is None: + return default + if isinstance(value, (int, float)): + if math.isnan(value) or math.isinf(value): + return default + return value + return value + +def extract_proto_fields(data, run_type): + """Extract and structure relevant fields for Spanner insertion, including `run_type`.""" + + config = { + 'model': data.get('config', {}).get('model', ''), + 'num_models': safe_json_value(data.get('config', {}).get('num_models', 0), 0), + 'model_server': data.get('config', {}).get('model_server', ''), + 'backend': data.get('dimensions', {}).get('backend', ''), + 'model_id': data.get('dimensions', {}).get('model_id', ''), + 'tokenizer_id': data.get('dimensions', {}).get('tokenizer_id', ''), + 'request_rate': safe_json_value(data.get('metrics', {}).get('request_rate', 0), 0), + 'benchmark_time': safe_json_value(data.get('metrics', {}).get('benchmark_time', 0), 0), + 'run_type': run_type + } + + infrastructure = { + 'model_server': config['model_server'], + 'backend': config['backend'], + 'gpu_cache_usage_p90': safe_json_value(data.get('metrics', {}).get('server_metrics', {}).get('vllm:gpu_cache_usage_perc', {}).get('P90', 0.0)), + 'num_requests_waiting_p90': safe_json_value(data.get('metrics', {}).get('server_metrics', {}).get('vllm:num_requests_waiting', {}).get('P90', 0.0)), + 'gpu_cache_usage_mean': safe_json_value(data.get('metrics', {}).get('server_metrics', {}).get('vllm:gpu_cache_usage_perc', {}).get('Mean', 0.0)), + 'num_requests_waiting_mean': safe_json_value(data.get('metrics', {}).get('server_metrics', {}).get('vllm:num_requests_waiting', {}).get('Mean', 0.0)), + } + + metrics = data.get('metrics', {}) + prompt_dataset = { + 'num_prompts_attempted': safe_json_value(metrics.get('num_prompts_attempted', 0), 0), + 'num_prompts_succeeded': safe_json_value(metrics.get('num_prompts_succeeded', 0), 0), + 'avg_input_len': safe_json_value(metrics.get('avg_input_len', 0.0)), + 'median_input_len': safe_json_value(metrics.get('median_input_len', 0.0)), + 'p90_input_len': safe_json_value(metrics.get('p90_input_len', 0.0)), + 'avg_output_len': safe_json_value(metrics.get('avg_output_len', 0.0)), + 'median_output_len': safe_json_value(metrics.get('median_output_len', 0.0)), + 'p90_output_len': safe_json_value(metrics.get('p90_output_len', 0.0)) + } + + summary_stats = { + 'p90_normalized_time_per_output_token_ms': safe_json_value(metrics.get('p90_normalized_time_per_output_token_ms', 0.0)), + 'avg_normalized_time_per_output_token_ms': safe_json_value(metrics.get('avg_normalized_time_per_output_token_ms', 0.0)), + 'throughput': safe_json_value(metrics.get('throughput', 0.0)), + 'input_tokens_per_sec': safe_json_value(metrics.get('input_tokens_per_sec', 0.0)), + 'benchmark_time': safe_json_value(metrics.get('benchmark_time', 0.0)), + 
'date': data.get('dimensions', {}).get('date', ''), + 'avg_latency_ms': safe_json_value(metrics.get('avg_latency_ms', 0.0)), + 'median_latency_ms': safe_json_value(metrics.get('median_latency_ms', 0.0)), + 'p90_latency_ms': safe_json_value(metrics.get('p90_latency_ms', 0.0)), + 'p99_latency_ms': safe_json_value(metrics.get('p99_latency_ms', 0.0)), + 'time_per_output_token_seconds_p90': safe_json_value(data.get('metrics', {}).get('server_metrics', {}).get('vllm:time_per_output_token_seconds', {}).get('P90', 0.0)), + 'time_to_first_token_seconds_p90': safe_json_value(data.get('metrics', {}).get('server_metrics', {}).get('vllm:time_to_first_token_seconds', {}).get('P90', 0.0)), + 'time_per_output_token_seconds_mean': safe_json_value(data.get('metrics', {}).get('server_metrics', {}).get('vllm:time_per_output_token_seconds', {}).get('Mean', 0.0)), + 'time_to_first_token_seconds_mean': safe_json_value(data.get('metrics', {}).get('server_metrics', {}).get('vllm:time_to_first_token_seconds', {}).get('Mean', 0.0)), + } + + return config, infrastructure, prompt_dataset, summary_stats + +def clean_for_json(obj): + """Recursively clean an object for JSON serialization.""" + if isinstance(obj, dict): + return {k: clean_for_json(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [clean_for_json(item) for item in obj] + elif isinstance(obj, float): + if math.isnan(obj) or math.isinf(obj): + return 0.0 + return obj + elif obj is None: + return 0.0 + else: + return obj + +def upload_to_spanner_batch_with_retry(instance_id, database_id, json_files, gcs_base_uri, run_type, max_retries=3): + """ + Upload JSON files to Spanner in batches with retry logic. + More efficient but fails entire batch if any single file has issues. + """ + spanner_client = spanner.Client() + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) + + print(f"📊 Uploading {len(json_files)} JSON files to Spanner with run_type='{run_type}'...") + + retry_count = 0 + success = False + processed_files = [] + + while retry_count <= max_retries and not success: + try: + processed_files = [] # Reset on each retry + + with database.batch() as batch: + for json_file in json_files: + try: + with open(json_file, 'r') as f: + data = json.load(f) + + config, infra, prompt, stats = extract_proto_fields(data, run_type) + filename = os.path.basename(json_file) + gcs_uri = f"{gcs_base_uri}/{filename}" + latency_profile_id = str(uuid.uuid4()) + + # Test JSON serialization before inserting + try: + config_clean = clean_for_json(config) + infra_clean = clean_for_json(infra) + prompt_clean = clean_for_json(prompt) + stats_clean = clean_for_json(stats) + + config_json = json.dumps(config_clean) + infra_json = json.dumps(infra_clean) + prompt_json = json.dumps(prompt_clean) + stats_json = json.dumps(stats_clean) + except (TypeError, ValueError) as json_error: + print(f"❌ JSON serialization failed for {json_file}: {json_error}") + continue + + batch.insert( + table='LatencyProfiles', + columns=['Id', 'Config', 'Infrastructure', 'PromptDataset', 'SummaryStats', 'GcsUri', 'InsertedAt'], + values=[ + (latency_profile_id, config_json, infra_json, prompt_json, stats_json, gcs_uri, spanner.COMMIT_TIMESTAMP) + ] + ) + + if 'core_deployment_artifacts' in data or 'extension_deployment_artifacts' in data: + core_json = json.dumps(data.get('core_deployment_artifacts', {})) + ext_json = json.dumps(data.get('extension_deployment_artifacts', {})) + batch.insert( + table='DeploymentArtifacts', + columns=['Id', 
'LatencyProfileId', 'CoreDeploymentArtifacts', 'ExtensionDeploymentArtifacts'], + values=[ + (str(uuid.uuid4()), latency_profile_id, core_json, ext_json) + ] + ) + + processed_files.append((json_file, latency_profile_id)) + + except Exception as e: + print(f"❌ Failed to process {json_file}: {e}") + continue + + # If we get here, the batch committed successfully + for json_file, profile_id in processed_files: + print(f"✅ {json_file} uploaded (ID: {profile_id})") + success = True + + except (gcp_exceptions.DeadlineExceeded, + gcp_exceptions.ServiceUnavailable, + gcp_exceptions.InternalServerError, + gcp_exceptions.TooManyRequests) as retryable_error: + retry_count += 1 + if retry_count <= max_retries: + wait_time = (2 ** retry_count) + random.uniform(0, 1) + print(f"⚠️ Batch upload failed (attempt {retry_count}/{max_retries}): {retryable_error}") + print(f"⏳ Retrying entire batch in {wait_time:.1f} seconds...") + time.sleep(wait_time) + else: + print(f"❌ Batch upload failed after {max_retries} retries: {retryable_error}") + + except Exception as batch_error: + retry_count += 1 + if retry_count <= max_retries: + wait_time = (2 ** retry_count) + random.uniform(0, 1) + print(f"⚠️ Unexpected batch error (attempt {retry_count}/{max_retries}): {batch_error}") + print(f"⏳ Retrying entire batch in {wait_time:.1f} seconds...") + time.sleep(wait_time) + else: + print(f"❌ Batch upload failed after {max_retries} retries: {batch_error}") + traceback.print_exc() + + if success: + print("✅ All files uploaded successfully.") + else: + print("❌ Upload process failed after all retries.") + + MIN_SEQ_LEN = 4 NEW_TEXT_KEY = "\nOutput:\n" PROMETHEUS_PORT = 9090 @@ -101,6 +296,8 @@ def get_filtered_dataset( dataset_path: str, max_input_len: int, max_output_len: int, + min_input_len: int, + min_output_len: int, tokenizer: PreTrainedTokenizerBase, use_dummy_text: bool, ) -> List[Tuple[str, int, int]]: @@ -139,7 +336,7 @@ def get_filtered_dataset( filtered_dataset: List[Tuple[str, int, int]] = [] for prompt, prompt_token_ids, output_len in tokenized_dataset: prompt_len = len(prompt_token_ids) - if prompt_len < MIN_SEQ_LEN or output_len < MIN_SEQ_LEN: + if prompt_len < min_input_len or output_len < min_output_len: # Prune too short sequences. # This is because TGI causes errors when the input or output length # is too short. 
@@ -194,12 +391,16 @@ async def send_stream_request( model: str, timeout: float, max_conn: int, + inference_objective: str, + target_model_header: Optional[str] = None, ) -> Tuple[Tuple[int, int, float], float, List[float], Dict[str, int]]: """Sends stream request to server""" request_start_time_ms = 1000 * time.time() errors = init_errors_map() - headers = {"User-Agent": "Benchmark Client"} + headers = {"User-Agent": "Benchmark Client", "x-gateway-inference-objective": inference_objective} + if target_model_header: + headers["x-gateway-model-name-rewrite"] = target_model_header if backend == "vllm": pload = { "model": model, @@ -303,12 +504,16 @@ async def send_request( model: str, timeout: float, max_conn: int, + inference_objective: str, + target_model_header: Optional[str] = None, ) -> Tuple[Tuple[int, int, float], float, List[float], Dict[str, int]]: """Sends request to server.""" request_start_time_ms = 1000 * time.time() errors = init_errors_map() - headers = {"User-Agent": "Benchmark Client"} + headers = {"User-Agent": "Benchmark Client", "x-gateway-inference-objective": inference_objective} + if target_model_header: + headers["x-gateway-model-name-rewrite"] = target_model_header if backend == "vllm": pload = { "model": model, @@ -447,17 +652,17 @@ async def send_request( async def run_single_request(args: argparse.Namespace, api_url: str, tokenizer: PreTrainedTokenizerBase, - prompt: str, prompt_len: int, output_len: int, chosen_model: str) -> Tuple[str, Tuple]: + prompt: str, prompt_len: int, output_len: int, chosen_model: str, target_model_header: Optional[str],) -> Tuple[str, Tuple]: if args.stream_request: result = await send_stream_request( args.backend, api_url, prompt, prompt_len, output_len, args.ignore_eos, args.best_of, args.use_beam_search, args.top_k, tokenizer, args.sax_model, - chosen_model, args.request_timeout, args.tcp_conn_limit) + chosen_model, args.request_timeout, args.tcp_conn_limit, args.inference_objective, target_model_header=target_model_header) else: result = await send_request( args.backend, api_url, prompt, prompt_len, output_len, args.ignore_eos, args.best_of, args.use_beam_search, args.top_k, tokenizer, args.sax_model, - chosen_model, args.request_timeout, args.tcp_conn_limit) + chosen_model, args.request_timeout, args.tcp_conn_limit, args.inference_objective, target_model_header=target_model_header) return chosen_model, result async def benchmark( @@ -471,7 +676,7 @@ async def benchmark( Also saves results separately for each model. 
""" input_requests = get_filtered_dataset( - args.dataset, args.max_input_length, args.max_output_length, tokenizer, args.use_dummy_text) + args.dataset, args.max_input_length, args.max_output_length, args.min_input_length, args.min_output_length, tokenizer, args.use_dummy_text) # Combine the models list and traffic split list into a dict @@ -486,6 +691,7 @@ async def benchmark( models_dict = dict(zip(models, traffic_split)) model_names = list(models_dict.keys()) model_weights = list(models_dict.values()) + target_map = args.targetmodels or {} benchmark_start_time_sec = time.time() # Initialize the counter with target prompts @@ -497,7 +703,8 @@ async def benchmark( break prompt, prompt_len, output_len = request chosen_model = random.choices(model_names, weights=model_weights)[0] - task = asyncio.create_task(run_single_request(args, api_url, tokenizer, prompt, prompt_len, output_len, chosen_model)) + target_model_header = target_map.get(chosen_model) + task = asyncio.create_task(run_single_request(args, api_url, tokenizer, prompt, prompt_len, output_len, chosen_model, target_model_header=target_model_header)) tasks.append(task) prompts_sent += 1 @@ -534,13 +741,13 @@ async def benchmark( await print_and_save_result(args, benchmark_duration_sec, prompts_sent, "weighted", overall_results["latencies"], overall_results["ttfts"], overall_results["itls"], overall_results["tpots"], - overall_results["errors"]) + overall_results["errors"], spanner_upload=True, server_metrics_scrape=True) for model, data in per_model_results.items(): await print_and_save_result(args, benchmark_duration_sec, len(data["latencies"]), model, data["latencies"], data["ttfts"], data["itls"], data["tpots"], data["errors"]) -def save_json_results(args: argparse.Namespace, benchmark_result, server_metrics, model, errors): +def save_json_results(args: argparse.Namespace, benchmark_result, server_metrics, model, errors, spanner_upload: bool = False): # Setup start_dt_proto = Timestamp() start_dt_proto.FromDatetime(args.start_datetime) @@ -636,6 +843,17 @@ def save_json_results(args: argparse.Namespace, benchmark_result, server_metrics print(f"File {file_name} uploaded to gs://{args.output_bucket}/{args.output_bucket_filepath}") except google.cloud.exceptions.NotFound: print(f"GS Bucket (gs://{args.output_bucket}) does not exist") + + if args.spanner_instance_id and args.spanner_database_id and spanner_upload: + # Upload to Spanner + try: + upload_to_spanner_batch_with_retry( + args.spanner_instance_id, args.spanner_database_id, [file_name], + args.output_bucket, args.file_prefix) + print(f"File {file_name} uploaded to Spanner") + except Exception as e: + print(f"Failed to upload {file_name} to Spanner: {e}") + def metrics_to_scrape(backend: str) -> List[str]: # Each key in the map is a metric, it has a corresponding 'stats' object @@ -815,7 +1033,7 @@ def get_stats_for_set(name, description, points): f'p99_{name}': p99, } -async def print_and_save_result(args: argparse.Namespace, benchmark_duration_sec, total_requests, model, request_latencies, ttfts, itls, tpots, errors): +async def print_and_save_result(args: argparse.Namespace, benchmark_duration_sec, total_requests, model, request_latencies, ttfts, itls, tpots, errors, spanner_upload=False, server_metrics_scrape=False): benchmark_result = {} print(f"====Result for Model: {model}====") @@ -882,10 +1100,10 @@ async def print_and_save_result(args: argparse.Namespace, benchmark_duration_sec } server_metrics = {} - if args.scrape_server_metrics: + if args.scrape_server_metrics 
and server_metrics_scrape: server_metrics = print_metrics(metrics_to_scrape(args.backend), benchmark_duration_sec, args.pm_namespace, args.pm_job) if args.save_json_results: - save_json_results(args, benchmark_result, server_metrics, model, errors) + save_json_results(args, benchmark_result, server_metrics, model, errors, spanner_upload) async def main(args: argparse.Namespace): print(args) @@ -930,7 +1148,22 @@ async def main(args: argparse.Namespace): await benchmark(args, api_url, tokenizer,models, args.traffic_split) - +def parse_targetmodels(arg: Optional[str]): + """Parse mappings like 'srcA:dstA,srcB:dstB' into a dict.""" + if arg is None: + return {} + mapping = {} + try: + for item in arg.split(','): + if not item.strip(): + continue + src, dst = item.split(':', 1) + mapping[src.strip()] = dst.strip() + except ValueError: + raise argparse.ArgumentTypeError( + "targetmodels must be 'src:dst,src2:dst2' (comma-separated 'src:dst' pairs)." + ) + return mapping def parse_traffic_split(arg): try: @@ -973,6 +1206,22 @@ def parse_traffic_split(arg): type=str, help="Comma separated list of models to benchmark.", ) + + parser.add_argument( + "--inference-objective", + type=str, + default="base-model", + help="Value for the 'x-gateway-inference-objective' header (e.g., 'base-model', 'slo-aware').", + ) + + parser.add_argument( + "--targetmodels", + type=parse_targetmodels, + default=None, + help="Optional mapping 'src_model:target_model,src2:target2'. If present and a chosen model matches a key, " + "the request header 'x-gateway-model-name-rewrite' will be set to the mapped target value." +) + parser.add_argument( "--traffic-split", type=parse_traffic_split, @@ -1022,7 +1271,23 @@ def parse_traffic_split(arg): type=int, default=1024, help=( - "Maximum number of input tokens for filtering the benchmark dataset." + "Maximum number of output tokens for filtering the benchmark dataset." + ), + ) + parser.add_argument( + "--min-input-length", + type=int, + default=4, + help=( + "Minimum number of input tokens for filtering the benchmark dataset." + ), + ) + parser.add_argument( + "--min-output-length", + type=int, + default=4, + help=( + "Minimum number of output tokens for filtering the benchmark dataset." ), ) parser.add_argument( @@ -1118,6 +1383,18 @@ def parse_traffic_split(arg): action="store_true", help="Whether to scrape server metrics.", ) + parser.add_argument( + "--spanner-instance-id", + type=str, + default=None, + help="Spanner instance ID to upload results to.", + ) + parser.add_argument( + "--spanner-database-id", + type=str, + default=None, + help="Spanner database ID to upload results to.", + ) parser.add_argument("--pm-namespace", type=str, default="default", help="namespace of the pod monitoring object, ignored if scrape-server-metrics is false") parser.add_argument("--pm-job", type=str, default="vllm-podmonitoring", help="name of the pod monitoring object, ignored if scrape-server-metrics is false") parser.add_argument("--tcp-conn-limit", type=int, default=100, help="Max number of tcp connections allowed per aiohttp ClientSession") diff --git a/latency_throughput_curve.sh b/latency_throughput_curve.sh index 69db546..383ddbb 100755 --- a/latency_throughput_curve.sh +++ b/latency_throughput_curve.sh @@ -1,27 +1,15 @@ #!/bin/bash - -# Copyright 2024 Google Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - +set -euo pipefail set -o xtrace -export IP=$IP +export IP=${IP:-localhost} -huggingface-cli login --token "$HF_TOKEN" --add-to-git-credential +huggingface-cli login --token "${HF_TOKEN:-}" --add-to-git-credential || true -if [[ "$PROMPT_DATASET" = "sharegpt" ]]; then +if [[ "${PROMPT_DATASET:-}" == "sharegpt" ]]; then PROMPT_DATASET_FILE="ShareGPT_V3_unfiltered_cleaned_split.json" +else + PROMPT_DATASET_FILE="${PROMPT_DATASET_FILE:-$PROMPT_DATASET}" fi PYTHON="python3" @@ -29,7 +17,7 @@ BASE_PYTHON_OPTS=( "benchmark_serving.py" "--save-json-results" "--host=$IP" - "--port=$PORT" + "--port=${PORT:-7080}" "--dataset=$PROMPT_DATASET_FILE" "--tokenizer=$TOKENIZER" "--backend=$BACKEND" @@ -37,43 +25,54 @@ BASE_PYTHON_OPTS=( "--max-output-length=$OUTPUT_LENGTH" "--file-prefix=$FILE_PREFIX" "--models=$MODELS" - "--pm-namespace=$PM_NAMESPACE" - "--pm-job=$PM_JOB" + "--pm-namespace=${PM_NAMESPACE:-default}" + "--pm-job=${PM_JOB:-vllm-podmonitoring}" ) -[[ "$TRAFFIC_SPLIT" ]] && BASE_PYTHON_OPTS+=("--traffic-split=$TRAFFIC_SPLIT") -[[ "$OUTPUT_BUCKET" ]] && BASE_PYTHON_OPTS+=("--output-bucket=$OUTPUT_BUCKET") -[[ "$SCRAPE_SERVER_METRICS" = "true" ]] && BASE_PYTHON_OPTS+=("--scrape-server-metrics") -[[ "$SAVE_AGGREGATED_RESULT" = "true" ]] && BASE_PYTHON_OPTS+=("--save-aggregated-result") -[[ "$STREAM_REQUEST" = "true" ]] && BASE_PYTHON_OPTS+=("--stream-request") -[[ "$IGNORE_EOS" = "true" ]] && BASE_PYTHON_OPTS+=("--ignore-eos") -[[ "$OUTPUT_BUCKET_FILEPATH" ]] && BASE_PYTHON_OPTS+=("--output-bucket-filepath" "$OUTPUT_BUCKET_FILEPATH") -[[ "$TCP_CONN_LIMIT" ]] && BASE_PYTHON_OPTS+=("--tcp-conn-limit" "$TCP_CONN_LIMIT") +[[ "${MIN_INPUT_LENGTH:-}" ]] && BASE_PYTHON_OPTS+=("--min-input-length=$MIN_INPUT_LENGTH") +[[ "${MIN_OUTPUT_LENGTH:-}" ]] && BASE_PYTHON_OPTS+=("--min-output-length=$MIN_OUTPUT_LENGTH") +[[ "${OUTPUT_BUCKET:-}" ]] && BASE_PYTHON_OPTS+=("--output-bucket=$OUTPUT_BUCKET") +[[ "${TRAFFIC_SPLIT:-}" ]] && BASE_PYTHON_OPTS+=("--traffic-split=$TRAFFIC_SPLIT") +[[ "${SCRAPE_SERVER_METRICS:-}" == "true" ]] && BASE_PYTHON_OPTS+=("--scrape-server-metrics") +[[ "${SAVE_AGGREGATED_RESULT:-}" == "true" ]] && BASE_PYTHON_OPTS+=("--save-aggregated-result") +[[ "${STREAM_REQUEST:-}" == "true" ]] && BASE_PYTHON_OPTS+=("--stream-request") +[[ "${IGNORE_EOS:-}" == "true" ]] && BASE_PYTHON_OPTS+=("--ignore-eos") +[[ "${OUTPUT_BUCKET_FILEPATH:-}" ]] && BASE_PYTHON_OPTS+=("--output-bucket-filepath" "$OUTPUT_BUCKET_FILEPATH") +[[ "${TCP_CONN_LIMIT:-}" ]] && BASE_PYTHON_OPTS+=("--tcp-conn-limit" "$TCP_CONN_LIMIT") +[[ "${SPANNER_INSTANCE_ID:-}" ]] && BASE_PYTHON_OPTS+=("--spanner-instance-id" "$SPANNER_INSTANCE_ID") +[[ "${SPANNER_DATABASE_ID:-}" ]] && BASE_PYTHON_OPTS+=("--spanner-database-id" "$SPANNER_DATABASE_ID") + +# Support TARGETMODELS or TARGET_MODELS +TARGETMODELS_EFFECTIVE="${TARGETMODELS:-${TARGET_MODELS:-}}" +[[ "$TARGETMODELS_EFFECTIVE" ]] && BASE_PYTHON_OPTS+=("--targetmodels" "$TARGETMODELS_EFFECTIVE") SLEEP_TIME=${SLEEP_TIME:-0} POST_BENCHMARK_SLEEP_TIME=${POST_BENCHMARK_SLEEP_TIME:-infinity} -for request_rate in $(echo $REQUEST_RATES | tr ',' ' '); do +if [[ -z 
"${REQUEST_RATES:-}" ]]; then + echo "ERROR: REQUEST_RATES is empty"; exit 1 +fi + +for request_rate in $(echo "$REQUEST_RATES" | tr ',' ' '); do echo "Benchmarking request rate: ${request_rate}" - # TODO: Check if profile already exists, if so then skip timestamp=$(date +"%Y-%m-%d_%H-%M-%S") output_file="latency-profile-${timestamp}.txt" - - if [ "$request_rate" == "0" ]; then + + if [[ "$request_rate" == "0" ]]; then request_rate="inf" - num_prompts=$MAX_NUM_PROMPTS + num_prompts="${MAX_NUM_PROMPTS:?Set MAX_NUM_PROMPTS when REQUEST_RATE=0}" else - num_prompts=$(awk "BEGIN {print int($request_rate * $BENCHMARK_TIME_SECONDS)}") + num_prompts=$(awk "BEGIN {print int($request_rate * ${BENCHMARK_TIME_SECONDS:-60})}") fi echo "TOTAL prompts: $num_prompts" PYTHON_OPTS=("${BASE_PYTHON_OPTS[@]}" "--request-rate=$request_rate" "--num-prompts=$num_prompts") - + $PYTHON "${PYTHON_OPTS[@]}" > "$output_file" cat "$output_file" echo "Sleeping for $SLEEP_TIME seconds..." - sleep $SLEEP_TIME + sleep "$SLEEP_TIME" done export LPG_FINISHED="true" -sleep $POST_BENCHMARK_SLEEP_TIME +sleep "$POST_BENCHMARK_SLEEP_TIME" diff --git a/requirements.txt b/requirements.txt index 1a9d7f7..c28f1b1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,6 +34,10 @@ aioprometheus[starlette] pynvml == 11.5.0 accelerate aiohttp + +# For Google Cloud Storage google-auth google-cloud-storage >= 2.18.2 prometheus_client >= 0.21.0 +google-cloud-spanner +google-api-core \ No newline at end of file